diff --git a/hkmc2/shared/src/main/scala/hkmc2/Uid.scala b/hkmc2/shared/src/main/scala/hkmc2/Uid.scala index 235b375c3..1b9ea27b0 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/Uid.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/Uid.scala @@ -13,6 +13,7 @@ object Uid: curUid def reset = curUid = -1 object Symbol extends Handler[semantics.Symbol] + object StratVar extends Handler[codegen.StratVar] extension [T] (x: Uid[T]) def <=(rhs: Uid[T]) = x <= rhs diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala index 3e609e6d5..f1b2540e7 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala @@ -406,6 +406,11 @@ enum Case: sealed trait TrivialResult extends Result +object Result: + opaque type ResultId = Int + private def ResultId(v: Int): ResultId = v + + sealed abstract class Result extends AutoLocated: protected def children: List[Located] = this match @@ -459,6 +464,15 @@ sealed abstract class Result extends AutoLocated: case DynSelect(qual, fld, arrayIdx) => qual.freeVarsLLIR ++ fld.freeVarsLLIR case Value.Rcd(args) => args.flatMap(arg => arg.idx.fold(Set.empty)(_.freeVarsLLIR) ++ arg.value.freeVarsLLIR).toSet + // for deforestation + def uid(using d: Deforest) = + import Result.* + val uidValue = ResultId(System.identityHashCode(this)) + d.resultIdToResult.updateWith(uidValue): + case N => S(this) + case S(r) => assert(this is r); S(this) + uidValue + // type Local = LocalSymbol type Local = Symbol diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Deforestation.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Deforestation.scala new file mode 100644 index 000000000..017c5233f --- /dev/null +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Deforestation.scala @@ -0,0 +1,1141 @@ +package hkmc2 +package codegen + +import semantics.* +import semantics.Elaborator.State +import syntax.{Literal, Tree} +import utils.* +import mlscript.utils.*, shorthands.* +import scala.collection.mutable +import scala.collection.mutable.LinkedHashMap +import Result.ResultId + +type StratVar +type StratVarId = Uid[StratVar] +type ClsOrModSymbol = ClassLikeSymbol + +sealed abstract class ProdStrat + +sealed abstract class ConsStrat + +class StratVarState(val uid: StratVarId, val name: Str = ""): + lazy val asProdStrat = ProdVar(this) + lazy val asConsStrat = ConsVar(this) + + override def toString(): String = s"${if name.isEmpty() then "var" else name}@${uid}" + +object StratVarState: + def freshVar(nme: String = "")(using vuid: Uid.StratVar.State) = + val newId = vuid.nextUid + val s = StratVarState(newId, nme) + val p = s.asProdStrat + val c = s.asConsStrat + p -> c + + +extension (i: ResultId) + def getResult(using d: Deforest) = d.resultIdToResult(i) + def handleCtorIds[A](k: (ResultId, Select | Value.Ref, ClsOrModSymbol, Ls[Arg]) => A)(using Deforest) = + def handleCallLike(f: Path, args: Ls[Arg]) = f match + case s: Select if s.symbol.flatMap(_.asCls).isDefined => + Some(k(i, s, s.symbol.get.asCls.get, args)) + case v: Value.Ref if v.l.asCls.isDefined => + Some(k(i, v, v.l.asCls.get, args)) + case _ => None + i.getResult match + case Call(fun, args) => handleCallLike(fun, args) + case Instantiate(cls, args) => handleCallLike(cls, args.map(Arg(false, _))) + case s: Select if s.symbol.flatMap(_.asObj).isDefined => + Some(k(i, s, s.symbol.get.asObj.get, Nil)) + case v: Value.Ref if v.l.asObj.isDefined => + Some(k(i, v, v.l.asObj.get, Nil)) + case _ => None + def getClsSymOfUid(using Deforest) = i.handleCtorIds((_, _, s, _) => s).get + +case class Ctor(ctor: ClsOrModSymbol, args: Map[TermSymbol, ProdStrat], expr: ResultId) extends ProdStrat +case class ProdFun(l: Ls[ConsStrat], r: ProdStrat) extends ProdStrat +case class ProdVar(s: StratVarState) extends ProdStrat with StratVarTrait(s) +case object NoProd extends ProdStrat + + + +class Dtor(val expr: Match, val outterMatch: Option[ResultId])(using d: Deforest) extends ConsStrat: + d.matchScrutToMatchBlock.updateWith(expr.scrut.uid): + case None => Some(expr) + case Some(_) => lastWords(s"should only update once (uid: ${expr.scrut.uid})") + d.matchScrutToParentMatchScrut.updateWith(expr.scrut.uid): + case None => Some(outterMatch) + case Some(_) => lastWords(s"should only update once (uid: ${expr.scrut.uid})") +object Dtor: + def unapply(d: Dtor)(using Deforest): Opt[ResultId] = S(d.expr.scrut.uid) + + +case class FieldSel(field: Tree.Ident, consVar: ConsVar)(val expr: ResultId, val inMatching: LinkedHashMap[ResultId, ClsOrModSymbol]) extends ConsStrat with FieldSelTrait +case class ConsFun(l: Ls[ProdStrat], r: ConsStrat) extends ConsStrat +case class ConsVar(s: StratVarState) extends ConsStrat with StratVarTrait(s) +case object NoCons extends ConsStrat + + +enum DtorExpr: + case Match(s: ResultId) + case Sel(s: ResultId) + +enum CtorFinalDest: + case Match(scrut: ResultId, expr: codegen.Match, selInArms: Ls[ResultId], selMaps: Map[Tree.Ident, Symbol] -> Map[ResultId, Symbol]) + case Sel(s: ResultId) + +trait FieldSelTrait: + this: FieldSel => + val filter = mutable.Map.empty[ProdVar, Ls[ClsOrModSymbol]].withDefaultValue(Nil) + + def updateFilter(p: ProdVar, c: Ls[ClsOrModSymbol]) = + filter += p -> (c ::: filter(p)) + +trait StratVarTrait(stratState: StratVarState): + this: ProdVar | ConsVar => + + lazy val asProdStrat = stratState.asProdStrat + lazy val asConsStrat = stratState.asConsStrat + lazy val uid = stratState.uid + +final case class NotDeforestableException(msg: String) extends Exception(msg) + +// Compute free vars for a block, without considering deforestation. +// Used on blocks after the deforestation transformation. +// This means that for matches we don't need to consider the extra +// free vars that may be introduced by deforestation: +// 1. the free vars from the `rest` of the their parent matches +// 2. the free vars caused by the substitution of selections of scrutinees of their parent matches +class FreeVarTraverser(alwaysDefined: Set[Symbol]) extends BlockTraverser: + val ctx = mutable.Set.from(alwaysDefined) + val result = mutable.Set.empty[Symbol] + + override def applyBlock(b: Block): Unit = b match + case Match(scrut, arms, dflt, rest) => + applyPath(scrut) + (arms.map(_._2) ++ dflt).foreach: a => + // dflt may just be `throw error``, and `rest` may use vars assigned in non default arms. + // So use `flattened` to remove dead code (after `throw error`) and spurious free vars. + val realArm = Begin(a, rest) + applyBlock(realArm) + + case Assign(lhs, rhs, rest) => + applyResult(rhs) + ctx += lhs + applyBlock(rest) + ctx -= lhs + case Begin(sub, rest) => applyBlock(b.flattened) + case Define(defn, rest) => defn match + case FunDefn(owner, sym, params, body) => + val paramSymbols = params.flatMap: + case ParamList(_, params, restParam) => (params ++ restParam).map: + case Param(sym = sym, _) => sym + ctx += sym + ctx ++= paramSymbols + applyBlock(body) + ctx --= paramSymbols + applyBlock(rest) + ctx -= sym + case ValDefn(owner, k, sym, rhs) => + ctx += sym + applyPath(rhs) + applyBlock(rest) + ctx -= sym + case c: ClsLikeDefn => ??? // not supported + + case _ => super.applyBlock(b) + + override def applyValue(v: Value): Unit = v match + case Value.Ref(l) => l match + // builtin symbols and toplevel symbols are always in scope + case _: (BuiltinSymbol | TopLevelSymbol) => () + // NOTE: assume all class definitions are in the toplevel + case b: BlockMemberSymbol if b.asClsLike.isDefined => () + case _ => if !ctx.contains(l) then result += l + case _ => super.applyValue(v) + + override def applyLam(l: Value.Lam): Unit = + val paramSymbols = l.params.params.map(p => p.sym) + ctx ++= paramSymbols + applyBlock(l.body) + ctx --= paramSymbols + + def analyze(b: Block) = + applyBlock(b) + result.toList.sortBy(_.uid) + +// Compute free vars for a Match block, considering deforestations. Used on non-transformed blocks +// Make use of `freeVarsOfNonTransformedMatches`, which computes the free vars _after transformation_ +// from non-transformed matches (either fusing or un-fusing matches). +// Sources of additional free vars for fusing matches: +// - the free vars from the `rest` of the their parent matches +// - the free vars caused by the substitution of selections of scrutinees of their parent matches +// Otherwise, additional free vars come from fusing matches that are contained in the block, +// and this is handled by freeVarsOfNonTransformedMatches +class DeforestationFreeVarTraverserForMatch( + alwaysDefined: Set[Symbol], + selsToBeReplaced: Map[ResultId, Symbol], + selsReplacementByCurrentMatch: Map[ResultId, Symbol], + currentMatchScrut: Symbol, + dt: DeforestTransformer +) extends FreeVarTraverser(alwaysDefined): + given Deforest = dt.d + override def applyBlock(b: Block): Unit = b match + // a nested match + case m@Match(scrut, arms, dflt, rest) => + result ++= dt.freeVarsOfNonTransformedMatches(scrut.uid, m) + + // sub-matches' scruts (which are not included in freeVarsOfNonTransformedMatches) + // are also free vars + val Value.Ref(l) = scrut + if !ctx(l) then result += l + + // free vars in nested-matches reported by freeVarsOfNonTransformedMatches may also contain + // spurious ones: those that are going to be substitued by the current match, + // and those that are in the ctx + result --= selsReplacementByCurrentMatch.values + result --= ctx + case _ => super.applyBlock(b) + + override def applyPath(p: Path): Unit = p match + case p @ Select(qual, name) => selsToBeReplaced.get(p.uid) match + case None => qual match + // if it is the scrut of current match and the computation containing + // this selection is moved, then the selection will be replaced and there will be no free vars + case Value.Ref(l) if l == currentMatchScrut => () + case _ => super.applyPath(p) + case Some(s) => result += s + case _ => super.applyPath(p) + + override def analyze(m: Block): List[Symbol] = + require(m.isInstanceOf[Match]) + val matchExpr@Match(scrut@Value.Ref(l), arms, dflt, rest) = m + val parentMatchRest = dt.allParentMatches(scrut.uid).foldRight[Block](End("")): (p, acc) => + Begin(dt.d.matchScrutToMatchBlock(p).rest, acc) + (arms.map(_._2) ++ dflt).foreach: a => + // dflt may just be `throw error``, and `rest` may use vars assigned in non default arms. + // So use `flattened` to remove dead code (after `throw error`) and spurious free vars. + // Also take care of the `rest`s of its parent match blocks. + val realArm = Begin(a, Begin(rest, parentMatchRest)).flattened + applyBlock(realArm) + + result.toList.sortBy(_.uid) + + +class WillBeNonEndTailBlockTraverser(using d: Deforest) extends BlockTraverserShallow: + var flag = false + override def applyBlock(b: Block): Unit = b match + case Match(scrut, arms, dflt, rest) => + flag = + d.rewritingMatchConsumers(scrut.uid) || + (arms.forall { case (_, b) => b.willBeNonEndTailBlock } && dflt.fold(true)(_.willBeNonEndTailBlock)) || + rest.willBeNonEndTailBlock + case _: End => () + case _: BlockTail => flag = true + case _ => super.applyBlock(b) + def analyze(b: Block): Bool = + applyBlock(b) + flag + +class ReplaceLocalSymTransformer(freeVarsAndTheirNewSyms: Map[Symbol, Symbol]) extends BlockTransformer(new SymbolSubst()): + override def applyValue(v: Value): Value = v match + case Value.Ref(l) => Value.Ref(freeVarsAndTheirNewSyms.getOrElse(l, l)) + case _ => super.applyValue(v) + +class HasExplicitRetTraverser extends BlockTraverserShallow: + var flag = false + override def applyBlock(b: Block): Unit = b match + case Return(_, imp) => flag = !imp + case _ => super.applyBlock(b) + + def analyze(b: Block) = + flag = false + applyBlock(b) + flag + +class GetCtorsTraverser(using Deforest) extends BlockTraverser: + val ctors = mutable.Set.empty[ResultId] + override def applyResult(r: Result): Unit = + r.uid.handleCtorIds{ (id, f, clsOrMod, args) => + ctors += id + args.foreach { case Arg(_, value) => applyResult(value) } + } match + case Some(_) => () + case None => r match + case Call(_, args) => + args.foreach { case Arg(_, value) => applyResult(value) } + case Instantiate(cls, args) => + args.foreach(applyResult) + case _ => () + +extension (b: Block) + def replaceSymbols(freeVarsAndTheirNewSyms: Map[Symbol, Symbol]) = + ReplaceLocalSymTransformer(freeVarsAndTheirNewSyms).applyBlock(b) + + def sortedFvsForTransformedBlocks(alwaysDefined: Set[Symbol]) = + FreeVarTraverser(alwaysDefined).analyze(b) + + def hasExplicitRet: Boolean = + HasExplicitRetTraverser().analyze(b) + + def willBeNonEndTailBlock(using d: Deforest): Bool = + WillBeNonEndTailBlockTraverser().analyze(b) + +class Deforest(using TL, Raise, Elaborator.State): + + given Uid.StratVar.State = Uid.StratVar.State() + given Deforest = this + import StratVarState.freshVar + + + val resultIdToResult = mutable.Map.empty[ResultId, Result] + + def apply(p: Program): Opt[Program] -> String -> Int = + val mainBlk = p.main + + globallyDefinedVars.init(mainBlk) + + // allocate type vars for defined symbols in the blocks + symToStrat.init(mainBlk) + + try + processBlock(mainBlk) + catch + case NotDeforestableException(msg) => + // return None if deforestation is not applicable + return N -> "" -> 0 + + resolveConstraints + + tl.log("-----------------------------------------") + ctorDests.ctorDests.foreach: + case (ctorExprId, CtorDest(matches, sels, noCons)) => tl.log: + val ctorName = ctorExprId.getClsSymOfUid.nme + s"(id:$ctorExprId)" + val matchExprScruts = "if " + matches.map{(s, m) => + m.scrut.asInstanceOf[Value.Ref].l.nme + s"(id:$s)" + }.toList.sorted.mkString(" | ") + " then ... " + val selExpr = sels.map{ + case sel@FieldSel(s, v) => s".${s.name}(id:${sel.expr})" + }.toList.sorted.mkString(" | ") + s"$ctorName\n\t --- match ---> $matchExprScruts\n\t --- sels ---> $selExpr\n\tNoCons: $noCons" + tl.log("-----------------------------------------") + dtorSources.dtorSources.foreach: + case (d, DtorSource(ctors, noProd)) => + tl.log(s"$d <--- ${ctors.map(c => c.getClsSymOfUid.nme + s"(id:$c)").toList.mkString(" | ")} <--- (NoProd: $noProd)") + tl.log("-----------------------------------------") + filteredCtorDests.foreach: + case (ctorUid, CtorFinalDest.Sel(s)) => tl.log(s"${ctorUid.getClsSymOfUid.nme}(id:$ctorUid) --sel--> " + s) + case (ctorUid, CtorFinalDest.Match(scrut, _, _, _)) => tl.log(s"${ctorUid.getClsSymOfUid.nme}(id:$ctorUid) --mat--> " + scrut ) + + val fusionStat = filteredCtorDests.map: + case (ctorUid, CtorFinalDest.Sel(s)) => + "\t" + ctorUid.getClsSymOfUid.nme + " --sel--> " + s"`.${resultIdToResult(s).asInstanceOf[Select].name}`" + case (ctorUid, CtorFinalDest.Match(scrut, expr, _, _)) => + "\t" + ctorUid.getClsSymOfUid.nme + " --match--> " + s"`if ${expr.scrut.asInstanceOf[Value.Ref].l.nme} is ...`" + + if filteredCtorDests.nonEmpty then + S(Program(p.imports, rewrite(mainBlk))) -> s"${filteredCtorDests.size} fusion opportunities:\n${fusionStat.toList.sorted.mkString("\n")}" -> filteredCtorDests.size + else + S(p) -> s"0 fusion opportunity" -> 0 + + object globallyDefinedVars: + val store = mutable.Set.from[Symbol](State.globalThisSymbol ::State.runtimeSymbol :: Nil) + + def apply(s: Symbol) = store.contains(s) + + def init(b: Block) = store ++= b.definedVars + + var constraints: Ls[ProdStrat -> ConsStrat] = Nil + + val matchScrutToMatchBlock = mutable.Map.empty[ResultId, Match] + val matchScrutToParentMatchScrut = mutable.Map.empty[ResultId, Option[ResultId]] + object symToStrat: + val store = mutable.Map.empty[Symbol, ProdVar] + val funSymsWithDefn = mutable.Set.empty[BlockMemberSymbol] + val usedFunSym = mutable.Set.empty[BlockMemberSymbol] + + def init(p: Block) = + if store.isEmpty then + object FreshVarForAllVars extends BlockTraverser: + override def applySymbol(s: Symbol): Unit = s match + case b: BlockMemberSymbol => + store += s -> freshVar(s.nme)._1 + b.trmImplTree.foreach: t => + if t.k is syntax.Fun then usedFunSym += b + case _: TempSymbol => store += s -> freshVar(s.nme)._1 + case _: VarSymbol => store += s -> freshVar(s.nme)._1 + case _: TermSymbol => store += s -> freshVar(s.nme)._1 + case _ => () + + override def applyFunDefn(fun: FunDefn): Unit = + funSymsWithDefn += fun.sym + super.applyFunDefn(fun) + + FreshVarForAllVars.applyBlock(p) + // `NoProd` to block fusion for those functions that are imported from elsewhere + usedFunSym.diff(funSymsWithDefn).foreach: funSymsWithoutDefn => + constrain(NoProd, store(funSymsWithoutDefn).asConsStrat) + + + def getStratOfSym(s: Symbol) = + s match + case _: BuiltinSymbol => NoProd + case _: TopLevelSymbol => NoProd + // TODO: cannot fuse intermediate values created by + // calling data constructors passed around like functions, + // like `fun app(ctor) = ctor(1); if app(AA) is AA(x) then x`; + // immediate data constructor calls are handled directly, + // so if this method is called on a ClsLike symbol, + // it means that this constructor is passed around like a function, + // which we can't fuse for now + case _ if s.asCls.isDefined => NoProd + case _: BlockMemberSymbol => store(s) + case _: LocalSymbol => store(s) + def +=(e: Symbol -> ProdVar) = store += e + def addAll(es: Iterable[Symbol -> ProdVar]) = store.addAll(es) + def apply(s: Symbol) = store(s) + + def getClsFields(s: ClassSymbol) = s.tree.clsParams + + + + def constrain(p: ProdStrat, c: ConsStrat) = constraints ::= p -> c + + def processBlock(b: Block)(using + inArm: Map[ProdVar, ClsOrModSymbol] = Map.empty[ProdVar, ClsOrModSymbol], + matching: LinkedHashMap[ResultId, ClsOrModSymbol] = LinkedHashMap.empty[ResultId, ClsOrModSymbol] + ): ProdStrat = b match + case m@Match(scrut, arms, dflt, rest) => + val scrutStrat = processResult(scrut) + constrain(scrutStrat, Dtor(m, matching.lastOption.map(_._1))) + val armsRes = if arms.forall{ case (cse, _) => cse.isInstanceOf[Case.Cls] } then + arms.map: + case (Case.Cls(s, _), body) => + processBlock(body)( + using inArm + (scrutStrat.asInstanceOf[ProdVar] -> s), + matching + (scrut.uid -> s) + ) + else + arms.map: + case (_, armBody) => processBlock(armBody) + val dfltRes = dflt.map(processBlock) + rest match + case End(msg) => + val matchRes = freshVar() + armsRes.appendedAll(dfltRes).foreach: r => + constrain(r, matchRes._2) + matchRes._1 + case _ => processBlock(rest) + + case Return(res, implct) => processResult(res) + case Assign(lhs, rhs, rest) => + constrain(processResult(rhs), symToStrat(lhs).asConsStrat) + processBlock(rest) + case Begin(sub, rest) => + processBlock(sub) + processBlock(rest) + case Define(defn, rest) => + defn match + case FunDefn(_, sym, params, body) => + val funSymStratVar = symToStrat(sym) + val param = params match + // TODO: handle `restParam` and mutiple param list + case ParamList(flags, params, N) :: Nil => params + val funStrat = constrFun(param, body) + constrain(funStrat, funSymStratVar.asConsStrat) + funSymStratVar + case v: ValDefn => throw NotDeforestableException("No support for `ValDefn` yet") + case c: ClsLikeDefn => throw NotDeforestableException("No support for `ClsLikeDefn` yet") + processBlock(rest) + case End(msg) => NoProd + // make it a type var instead of `NoProd` so that things like `throw match error` in + // default else branches do not block fusion... + case Throw(exc) => + processResult(exc) + freshVar("throw")._1 + + def constrFun(params: Ls[Param], body: Block)(using + inArm: Map[ProdVar, ClsOrModSymbol], + matching: LinkedHashMap[ResultId, ClsOrModSymbol] + ) = + val paramSyms = params.map: + case Param(sym = sym, _) => sym + val paramStrats = paramSyms.map(symToStrat.apply) + symToStrat.addAll(paramSyms.zip(paramStrats)) + val res = freshVar() + constrain(processBlock(body), res._2) + ProdFun(paramStrats.map(s => s.asConsStrat), res._1) + + def processResult(r: Result)(using + inArm: Map[ProdVar, ClsOrModSymbol], + matching: LinkedHashMap[ResultId, ClsOrModSymbol] + ): ProdStrat = + def handleCallLike(f: Path, args: Ls[Path], c: Result) = + val argsTpe = args.map(processResult) + f match + case s@Select(p, nme) => + s.symbol.map(_.asCls) match + case None => + val pStrat = processResult(p) + val tpeVar = freshVar() + constrain(pStrat, FieldSel(nme, tpeVar._2)(s.uid, matching)) + val appRes = freshVar() + constrain(tpeVar._1, ConsFun(argsTpe, appRes._2)) + appRes._1 + case Some(None) => + val funSym = s.symbol.get + val appRes = freshVar("call_" + funSym.nme + "_res") + constrain(symToStrat.getStratOfSym(funSym), ConsFun(argsTpe, appRes._2)) + appRes._1 + case Some(Some(s)) => + val clsFields = getClsFields(s) + Ctor(s, clsFields.zip(argsTpe).toMap, c.uid) + case Value.Ref(l) => + l.asCls match + case Some(s) => + val clsFields = getClsFields(s) + Ctor(s, clsFields.zip(argsTpe).toMap, c.uid) + case _ => // then it is a function + val appRes = freshVar("call_" + l.nme + "_res") + constrain(symToStrat.getStratOfSym(l), ConsFun(argsTpe, appRes._2)) + appRes._1 + case lam@Value.Lam(params, body) => + val funTpe = processResult(lam) + val appRes = freshVar() + constrain(funTpe, ConsFun(argsTpe, appRes._2)) + appRes._1 + + case Value.This(sym) => throw NotDeforestableException("No support for `this` as a callee yet") + case Value.Lit(lit) => ??? + case Value.Arr(elems) => ??? + r match + case c@Call(f, args) => handleCallLike(f, args.map {case Arg(false, value) => value}, c) + + case i@Instantiate(cls, args) => handleCallLike(cls, args, i) + + case sel@Select(p, nme) => sel.symbol match + case Some(s) if s.asObj.isDefined => + Ctor(s.asObj.get, Map.empty, sel.uid) + case _ => + val pStrat = processResult(p) + pStrat match + case ProdVar(pStratVar) if inArm.contains(pStratVar.asProdStrat) => + val tpeVar = freshVar() + val selStrat = FieldSel(nme, tpeVar._2)(sel.uid, matching) + selStrat.updateFilter(pStratVar.asProdStrat, inArm(pStratVar.asProdStrat) :: Nil) + constrain(pStrat, selStrat) + tpeVar._1 + case _ => + val tpeVar = freshVar() + constrain(pStrat, FieldSel(nme, tpeVar._2)(sel.uid, matching)) + tpeVar._1 + + case v@Value.Ref(l) => l.asObj match + case None => symToStrat.getStratOfSym(l) + case Some(m) => Ctor(m, Map.empty, v.uid) + + case Value.This(sym) => throw NotDeforestableException("No support for `this` yet") + case Value.Lit(lit) => NoProd + case Value.Lam(ParamList(_, params, N), body) => + constrFun(params, body) + case Value.Arr(elems) => throw NotDeforestableException("No support for arrays yet") + + + val upperBounds = mutable.Map.empty[StratVarId, Ls[ConsStrat]].withDefaultValue(Nil) + val lowerBounds = mutable.Map.empty[StratVarId, Ls[ProdStrat]].withDefaultValue(Nil) + + case class CtorDest(matches: Map[ResultId, Match], sels: Ls[FieldSel], noCons: Bool) + case class DtorSource(ctors: Set[ResultId], noProd: Bool) + object ctorDests: + val ctorDests = mutable.LinkedHashMap.empty[ResultId, CtorDest].withDefaultValue(CtorDest(Map.empty, Nil, false)) + def update(ctor: ResultId, m: Match) = ctorDests.updateWith(ctor): + case Some(CtorDest(matches, sels, noCons)) => Some(CtorDest(matches + (m.scrut.uid -> m), sels, noCons)) + case None => Some(CtorDest(Map(m.scrut.uid -> m), Nil, false)) + def update(ctor: ResultId, s: FieldSel) = ctorDests.updateWith(ctor): + case Some(CtorDest(matches, sels, noCons)) => Some(CtorDest(matches, s :: sels, noCons)) + case None => Some(CtorDest(Map.empty, s :: Nil, false)) + def update(ctor: ResultId, n: NoCons.type) = ctorDests.updateWith(ctor): + case Some(CtorDest(matches, sels, noCons)) => Some(CtorDest(matches, sels, true)) + case None => Some(CtorDest(Map.empty, Nil, true)) + def get(ctor: ResultId) = ctorDests.get(ctor) + + object dtorSources: + val dtorSources = mutable.Map.empty[DtorExpr, DtorSource].withDefaultValue(DtorSource(Set.empty, false)) + private def getDtorExprOfResultId(i: ResultId) = i.getResult match + case s: Select => DtorExpr.Sel(i) + case r: Value.Ref => DtorExpr.Match(i) + case r => lastWords(s"try to get dtor expr from ResultId, but get $r") + def update(dtor: ResultId, ctor: ResultId) = + val dtorExpr = getDtorExprOfResultId(dtor) + dtorSources.updateWith(dtorExpr): + case None => Some(DtorSource(Set(ctor), false)) + case Some(DtorSource(ctors, noProd)) => Some(DtorSource(ctors + ctor, noProd)) + def update(dtor: ResultId, noProd: NoProd.type) = + val dtorExpr = getDtorExprOfResultId(dtor) + dtorSources.updateWith(dtorExpr): + case None => Some(DtorSource(Set.empty, true)) + case Some(DtorSource(ctors, noProd)) => Some(DtorSource(ctors, true)) + def get(dtor: ResultId) = dtorSources.get(getDtorExprOfResultId(dtor)) + + + + def resolveConstraints: Unit = + + def handle(c: ProdStrat -> ConsStrat)(using cache: mutable.Set[ProdStrat -> ConsStrat]): Unit = + val prod = c._1 + val cons = c._2 + + if cache(c) then return () + + cache += c + + (prod, cons) match + case (Ctor(ctor, args, expr), dtorStrat@Dtor(scrut)) => + ctorDests.update(expr, dtorStrat.expr) + dtorSources.update(scrut, expr) + case (Ctor(ctor, args, expr), selDtor@FieldSel(field, consVar)) => + ctorDests.update(expr, selDtor) + dtorSources.update(selDtor.expr, expr) + args.find(a => a._1.id == field).map: p => + handle(p._2 -> consVar) + case (Ctor(ctor, args, _), ConsFun(l, r)) => () // ignore + + case (p: ProdVar, _) => + upperBounds += p.uid -> (cons :: upperBounds(p.uid)) + lowerBounds(p.uid).foreach: l => + (l, cons) match + case (l: ProdVar, sel@FieldSel(field, consVar)) => + sel.updateFilter(l, sel.filter(p)) + handle(l -> cons) + case (Ctor(ctor, args, _), sel@FieldSel(field, consVar)) => + if sel.filter.get(p).forall(_.contains(ctor)) then + handle(l -> cons) + else + () + case _ => handle(l -> cons) + case (_, c: ConsVar) => + lowerBounds += c.uid -> (prod :: lowerBounds(c.uid)) + upperBounds(c.uid).foreach: u => + (prod, u) match + case (Ctor(ctor, args, _), sel@FieldSel(field, consVar)) => + if sel.filter.get(c.asProdStrat).forall(_.contains(ctor)) then + handle(prod -> u) + else + () + case (_: ProdVar, _) => die + case _ => handle(prod -> u) + case (Ctor(ctor, args, expr), NoCons) => + ctorDests.update(expr, NoCons) + args.valuesIterator.foreach(a => handle(a, NoCons)) + case (ProdFun(l, r), Dtor(cls)) => () // ignore + case (ProdFun(l, r), FieldSel(field, consVar)) => () // ignore + case (ProdFun(lp, rp), ConsFun(lc, rc)) => + lc.zip(lp).foreach(handle) + handle(rp, rc) + case (ProdFun(l, r), NoCons) => + l.foreach(a => handle(NoProd, a)) + handle(r, NoCons) + case (NoProd, Dtor(scrut)) => dtorSources.update(scrut, NoProd) + case (NoProd, fSel@FieldSel(field, consVar)) => dtorSources.update(fSel.expr, NoProd) + case (NoProd, ConsFun(l, r)) => + l.foreach(a => handle(a, NoCons)) + handle(NoProd, r) + case (NoProd, NoCons) => () + + constraints.foreach(c => handle(c)(using mutable.Set.empty)) + + + // ======== after resolving constraints ====== + + lazy val resolveClashes = + val ctorToDtor = ctorDests.ctorDests + val dtorToCtor = dtorSources.dtorSources + + def removeCtor(rm: ResultId): Unit = + for CtorDest(mat, sels, _) <- ctorToDtor.remove(rm) do + for s <- mat.keys do removeDtor(DtorExpr.Match(s)) + for s <- sels do removeDtor(DtorExpr.Sel(s.expr)) + + def removeDtor(rm: DtorExpr) = + for + c <- dtorToCtor.remove(rm) + x <- c.ctors + do + removeCtor(x) + + // remove clashes: + ctorToDtor.filterNot { case _ -> CtorDest(dtors, sels, noCons) => + ((dtors.size == 0 && sels.size == 1) + || (dtors.size == 1 && { + val scrutRef@Value.Ref(scrut) = dtors.head._1.getResult + sels.forall { s => s.expr.getResult match + case Select(Value.Ref(l), nme) => (l === scrut) && s.inMatching.contains(scrutRef.uid) // need to be in the matching arms, and checking the scrutinee + case _ => false } + })) + && !noCons + }.keys.foreach(removeCtor) + dtorToCtor.filter(_._2.noProd).keys.foreach(removeDtor) + + // remove cycle: + def getCtorInArm(ctor: ResultId, dtor: Match) = + val ctorSym = getClsSymOfUid(ctor) + val arm = dtor.arms.find{ case (Case.Cls(c1, _) -> body) => c1 === ctorSym }.map(_._2).orElse(dtor.dflt).get + val traverser = GetCtorsTraverser() + traverser.applyBlock(arm) + traverser.ctors + + def findCycle(ctor: ResultId, dtor: Match): Ls[ResultId] = + val cache = mutable.Set(ctor) + def go(ctorAndMatches: Ls[ResultId -> Match]): Ls[ResultId] = + var newCtorsAndNewMatches: Ls[ResultId -> Match] = Nil + for + (c, m) <- ctorAndMatches + c <- getCtorInArm(c, m) + CtorDest(matches, sels, _) <- ctorToDtor.get(c) + m <- matches.values.headOption + do newCtorsAndNewMatches = (c -> m) :: newCtorsAndNewMatches + val cycled = newCtorsAndNewMatches.filter(c => !cache.add(c._1)) + if newCtorsAndNewMatches.isEmpty then + Nil + else if cycled.nonEmpty then + cycled.map(_._1) + else + go(newCtorsAndNewMatches) + go(Ls(ctor -> dtor)) + + for + (c, CtorDest(matches, sels, _)) <- ctorToDtor + m <- matches.values + x <- findCycle(c, m) + do removeCtor(x) + + ctorToDtor -> dtorToCtor + + + + lazy val filteredCtorDests: Map[ResultId, CtorFinalDest] = + val res = mutable.Map.empty[ResultId, CtorFinalDest] + + // we need only one CtorFinalDest per arm for each pat mat expr + val handledMatches = mutable.Map.empty[ResultId -> ClsOrModSymbol, CtorFinalDest] + + resolveClashes._1.foreach { case (ctor, CtorDest(dtors, sels, false)) => + val filteredDtor = { + if dtors.size == 0 && sels.size == 1 then CtorFinalDest.Sel(sels.head.expr) + else if dtors.size == 0 && sels.size > 1 then + lastWords("more than one consumer") + else if dtors.size > 1 then + lastWords("more than one consumer") + else if dtors.size == 1 then + val currentCtorCls = getClsSymOfUid(ctor) + val scrutRef@Value.Ref(scrut) = dtors.head._1.getResult + handledMatches.getOrElseUpdate( + scrutRef.uid -> currentCtorCls, + if sels.forall{ s => s.expr.getResult match + case Select(Value.Ref(l), nme) => (l === scrut) && s.inMatching.contains(scrutRef.uid) + case _ => false + } then + val fieldNameToSymToBeReplaced = mutable.Map.empty[Tree.Ident, Symbol] + val selectionUidsToSymToBeReplaced = mutable.Map.empty[ResultId, Symbol] + + dtors.head._2.arms.foreach: + case (Case.Cls(cOrMod, _), body) if cOrMod.asCls.fold(false)(_ === currentCtorCls) => + val c = cOrMod.asCls.get + // if this arm is used more than once, should be var symbol because the arm body will be + // extracted to a function, otherwise just temp symbol + val varSymInsteadOfTempSym = resolveClashes._2(DtorExpr.Match(dtors.head._1)).ctors.count(getClsSymOfUid(_) === c) > 1 + val selsInArms = sels.filter { fs => fs.inMatching(dtors.head._1) === c } + + selsInArms.foreach: fs => + assert(getClsFields(c).map(_.id).contains(fs.field)) + fieldNameToSymToBeReplaced.updateWith(fs.field): + case Some(v) => Some(v) + case None => Some(if varSymInsteadOfTempSym + then VarSymbol(Tree.Ident(s"_deforest_${c.name}_${fs.field.name}")) + else TempSymbol(N, s"_deforest_${c.name}_${fs.field.name}")) + val sym = fieldNameToSymToBeReplaced(fs.field) + + selectionUidsToSymToBeReplaced.addOne(fs.expr -> sym) + case _ => () + CtorFinalDest.Match( + dtors.head._1, + dtors.head._2, + sels.map(_.expr), + fieldNameToSymToBeReplaced.toMap -> selectionUidsToSymToBeReplaced.toMap + ) + else + lastWords("more than one consumer") + ) + else die + } + res.updateWith(ctor){_ => Some(filteredDtor)} + } + res.toMap + + lazy val rewritingSelConsumers = filteredCtorDests.values.collect { + case CtorFinalDest.Sel(s) => s + }.toSet + + lazy val rewritingMatchConsumers = filteredCtorDests.values.collect { + case CtorFinalDest.Match(scrut = s, _) => s + }.toSet + + def rewrite(p: Block) = + val deforestTransformer = DeforestTransformer() + val rest = deforestTransformer.applyBlock(p) + val newDefsRest = deforestTransformer.matchRest.getAllFunDefs + val newDefsArms = deforestTransformer.matchArms.getAllFunDefs + newDefsArms(newDefsRest(rest)) + +class DeforestTransformer(using val d: Deforest, elabState: Elaborator.State) extends BlockTransformer(new SymbolSubst()): + self => + val nonFreeVars: Set[Symbol] = d.globallyDefinedVars.store.toSet + + val replaceSelInfo: Map[ResultId, Symbol] = + d.filteredCtorDests.values.flatMap { + case CtorFinalDest.Match(_, _, _, selMaps) => + selMaps._2 + case CtorFinalDest.Sel(s) => Nil + }.toMap + + def parentMatchesUptoAFusingOne(scrutId: ResultId) = + def go(scrutId: ResultId): List[ResultId] -> Opt[ResultId] = + d.matchScrutToParentMatchScrut(scrutId).fold(Nil -> N): r => + if d.rewritingMatchConsumers.contains(r) + then Nil -> S(r) + else + val res = go(r) + (r :: res._1) -> res._2 + go(scrutId) + + def allParentMatches(scrutId: ResultId) = + def go(scrutId: ResultId): List[ResultId] = + d.matchScrutToParentMatchScrut(scrutId).fold(Nil)(r => r :: go(r)) + go(scrutId) + + object freeVarsOfNonTransformedMatches: + val store = mutable.Map.empty[ResultId, List[Symbol]] + + private val toBeReplacedForAllBranches = mutable.Map.empty[ResultId, Map[ResultId, Symbol]].withDefaultValue(Map.empty) + d.filteredCtorDests.values.foreach: + case CtorFinalDest.Match(scrut, expr, selInArms, selMaps) => + toBeReplacedForAllBranches += scrut -> (toBeReplacedForAllBranches(scrut) ++ selMaps._2) + case CtorFinalDest.Sel(s) => () + + def apply(scrutExprId: ResultId, m: Match) = store.getOrElseUpdate( + scrutExprId, + locally: + assert(m.scrut.uid === scrutExprId) + val Match(Value.Ref(l), _, _, _) = m + val selReplacementNotForThisSel = replaceSelInfo -- toBeReplacedForAllBranches(scrutExprId).keys + DeforestationFreeVarTraverserForMatch( + nonFreeVars, + selReplacementNotForThisSel, + toBeReplacedForAllBranches(scrutExprId), + l, + self + ).analyze(m) + ) + + object matchArms: + val store = LinkedHashMap.empty[ResultId, Map[ClsOrModSymbol | None.type, FunDefn]].withDefaultValue(Map.empty) + + // return a lambda, which either calls the extracted arm function, or contains the computations in matching arms + def getOrElseUpdate( + scrut: ResultId, + m: Match, + cls: ClsOrModSymbol, + sel: Set[ResultId], + currentUsedCtorArgsToFields: Map[Tree.Ident, Value.Ref], + preComputedSymbols: Map[Tree.Ident, Symbol] -> Map[ResultId, Symbol] = Map.empty -> Map.empty + ) = + assert(scrut === m.scrut.uid) + val freeVarsAndTheirNewSyms = freeVarsOfNonTransformedMatches(scrut, m).map(s => s -> VarSymbol(Tree.Ident(s.nme))) + val (body, isDflt) = m.arms.find{ case (Case.Cls(c1, _) -> _) => c1 === cls }.map(_._2 -> false).orElse(m.dflt.map(_ -> true)).get + store.get(scrut).flatMap(_.get(if isDflt then None else cls)) match + case None => // not registered before, or this branch of this match will only appear once + val rest = m.rest + val makeBody = matchRest.getOrElseUpdate(scrut, rest) match + case N -> rewrittenRest => (bodyBlk: Block) => + Begin(bodyBlk, rewrittenRest).flattened.replaceSymbols(freeVarsAndTheirNewSyms.toMap).mapTail: + case Return(res, implct) => Return(res, false) + case t => t + case Some(f) -> rewrittenRest => (bodyBlk: Block) => + Begin( + bodyBlk, + Return( + Call( + Value.Ref(f), + rewrittenRest.sortedFvsForTransformedBlocks(nonFreeVars).map(a => Arg(false, Value.Ref(a))))(true, false), + false + ) + ).flattened.replaceSymbols(freeVarsAndTheirNewSyms.toMap).mapTail: + case Return(res, implct) => Return(res, false) + case t => t + + if d.resolveClashes._2(DtorExpr.Match(scrut)).ctors.count{c => + if !isDflt then c.getClsSymOfUid === cls + else m.arms.find{ case (Case.Cls(c1, _), _) => c1 === c.getClsSymOfUid }.isEmpty + } > 1 then + // make a function, and register, and return a lambda calling that function with correct arguments + // arguments for lambda: free vars + // arguments for that function: free vars and pattern vars + + val bodyReplaceSel = applyBlock(body) + + val freeVarsAndTheirNewSymsInLam = freeVarsAndTheirNewSyms.map(s => s._1 -> VarSymbol(s._2.id)) + val funBody = makeBody(bodyReplaceSel) + val funSym = BlockMemberSymbol(s"match_${scrut.getResult.asInstanceOf[Value.Ref].l.nme}_branch_${if isDflt then "dflt" else cls.nme}", Nil) + val newDef = FunDefn( + N, + funSym, + ParamList( + ParamListFlags.empty, + freeVarsAndTheirNewSyms.map(s => Param(FldFlags.empty, s._2, N, Modulefulness.none)).toList + ::: preComputedSymbols._1.toList.sortBy(_._1.name).map(v => + Param(FldFlags.empty, v._2.asInstanceOf[VarSymbol], N, Modulefulness.none) + ), + N + ) :: Nil, + funBody + ) + store += (scrut -> (store(scrut) + ((if isDflt then None else cls) -> newDef))) + Value.Lam( + ParamList(ParamListFlags.empty, freeVarsAndTheirNewSymsInLam.map(s => Param(FldFlags.empty, s._2, N, Modulefulness.none)), N), + Return( + Call(Value.Ref(funSym), freeVarsAndTheirNewSymsInLam.map(a => Arg(false, Value.Ref(a._2))) ::: currentUsedCtorArgsToFields.toList.sortBy(_._1.name).map(a => Arg(false, a._2)))(true, false), + false + ) + ) + else + val bodyReplaceSel = applyBlock(body) + val lambdaBody = makeBody(bodyReplaceSel) + Value.Lam( + ParamList(ParamListFlags.empty, freeVarsAndTheirNewSyms.values.map(s => Param(FldFlags.empty, s, N, Modulefulness.none)).toList, N), + lambdaBody + ) + + case Some(f) => + // return a lambda that calls f with correct arguments + Value.Lam( + ParamList(ParamListFlags.empty, freeVarsAndTheirNewSyms.map(s => Param(FldFlags.empty, s._2, N, Modulefulness.none)), N), + Return( + Call(Value.Ref(f.sym), freeVarsAndTheirNewSyms.map(a => Arg(false, Value.Ref(a._2))) ::: currentUsedCtorArgsToFields.toList.sortBy(_._1.name).map(a => Arg(false, a._2)))(true, false), + false + ) + ) + + def getAllFunDefs: Block => Block = + store.values.flatMap(v => v.values).foldRight(identity: Block => Block): + case (defn, k) => r => Define(defn, k(r)) + + object matchRest: + val store = LinkedHashMap.empty[ResultId, Opt[FunDefn] -> Block] + + def getAllDefined = store.valuesIterator.flatMap(_._1.map(_.sym)) + + // returns the symbol for the rest function (if any), and the rewritten rest block + def getOrElseUpdate(s: ResultId, restBeforeRewriting: Block): Opt[Symbol] -> Block = + store.get(s) match + case Some(f, b) => f.map(_.sym) -> b + case _ => + // return all blocks concat together using `Begin`, and return if all of them are `End` blocks + def concatAllRestBlocksOfMatches(ps: List[ResultId]) = + ps.foldRight[Block](End("")){ (pid, acc) => + val b = d.matchScrutToMatchBlock(pid).rest + val isEnd = b.isInstanceOf[End] + if isEnd then acc else Begin(b, acc) + } + val parentRestInfo = parentMatchesUptoAFusingOne(s) match + case ps -> Some(theFusingOne) => + // return the original rests from unfused parent matches, + // and the function symbol for the `rest` of the fusing parent match (if any) + // and the rewritten `rest` block of that fusing parent match + concatAllRestBlocksOfMatches(ps) -> + getOrElseUpdate(theFusingOne, d.matchScrutToMatchBlock(theFusingOne).rest) + case ps -> None => + // return the original rests from unfused parent matches, and none (meaning that there is no fusing parent match) + concatAllRestBlocksOfMatches(ps) -> None + + // bd: original `rest`s of non-fusing parent matches + val restRewritten = + val nonFlatten = parentRestInfo match + // None: there is no fusing parent match + case bd -> None => applyBlock(Begin(restBeforeRewriting, bd)) + // (Some(s), b): there is a fusing parent match, and its `rest` is extracted into a function with symbol `s`, + // and the transformed `rest` is b + case bd -> (Some(s), b) => Begin( + applyBlock(restBeforeRewriting), + Return(Call(Value.Ref(s), b.sortedFvsForTransformedBlocks(nonFreeVars ++ getAllDefined).map(a => Arg(false, Value.Ref(a))))(true, false), false)) + // (None, b): there is a fusing parent match, and its `rest` is not extracted into a function + case bd -> (None, b) => Begin(applyBlock(Begin(restBeforeRewriting, bd)), b) + nonFlatten.flattened + + // no need to build a new function for empty rest, or if the rest is only going to be used once + if restRewritten.isInstanceOf[End] || (d.resolveClashes._2(DtorExpr.Match(s)).ctors.map(c => c.getClsSymOfUid).size == 1) then + val res = N -> restRewritten + store += s -> res + res + else // build a new function and update the store + val scrutName = s.getResult.asInstanceOf[Value.Ref].l.nme + val sym = BlockMemberSymbol(s"match_${scrutName}_rest", Nil) + val freeVarsAndTheirNewSyms = restRewritten.sortedFvsForTransformedBlocks(nonFreeVars ++ getAllDefined).map(s => s -> VarSymbol(Tree.Ident(s.nme))).toMap + val newFunDef = FunDefn( + N, + sym, + ParamList(ParamListFlags.empty, freeVarsAndTheirNewSyms.values.map(s => Param(FldFlags.empty, s, N, Modulefulness.none)).toList, N) :: Nil, + restRewritten.replaceSymbols(freeVarsAndTheirNewSyms) + ) + store += s -> (Some(newFunDef) -> restRewritten) + Some(sym) -> restRewritten + + def getAllFunDefs: Block => Block = + store.values.foldRight(identity: Block => Block): + case (defn -> _, k) => + r => defn match + case None => k(r) + case Some(defn) => Define(defn, k(r)) + + + override def applyBlock(b: Block): Block = b match + case mat@Match(scrut, arms, dflt, rest) if arms.forall{ case (cse, _) => cse.isInstanceOf[Case.Cls] } && d.rewritingMatchConsumers.contains(scrut.uid) => + // since all fusing matches will be considered to be in the tail position, + // if any of the parent `rest`s has explicit return, the rewritten match will have explicit return + val oneOfParentMatchRestHasExplicitRet = allParentMatches(scrut.uid).foldRight(false) { (pid, acc) => acc || d.matchScrutToMatchBlock(pid).rest.hasExplicitRet } + val needExplicitRet = rest.hasExplicitRet || arms.exists(_._2.hasExplicitRet) || oneOfParentMatchRestHasExplicitRet + val freeVars = freeVarsOfNonTransformedMatches(scrut.uid, mat).map(v => Arg(false, Value.Ref(v))) + Return(Call(scrut, freeVars)(false, false), !needExplicitRet) + case Match(scrut, arms, dflt, rest) + if + // If all the arms end with non-`End` blocks, then the `rest` of this `Match` will never be executed, + // and we remove the `rest` in this case. This prevents `rest` to use variables that become + // undefined because computation in arms that defines them are moved away. + // One example illustrating the case of "deadcode using never assigned variable causing scope error during JS generation" is as follows: + // The mlscript program is: + // ``` + // fun test(x) = + // let t = if x is + // AA(AA(a)) then a + // t + 5 + // fun f(a) = if a is + // AA then 0 + // let p = AA(AA(10)) + // test(p) + f(p) + // ``` + // After lowering, it is essentially: + // ``` + // fun test(x) = + // if x is AA(param0) then + // if param0 is AA(param1) then + // a = param1 + // tmpRes = a + // else throw "match error" + // else throw "match error" + // t = tmpRes + // return t + 5 + // fun f(a) = if a is AA then 0 + // let p = AA(AA(10)) + // test(p) + f(p) + // ``` + // And after fusion, the program (before the removal of dead code causing scope error) is: + // ``` + // fun test(x) = + // if x is AA(param0) then + // param0() + // else throw "match error" + // t = tmpRes // <--- this `tmpRes` without binding site causes scope error + // return t + 5 + // fun f(a) = if a is AA then 0 + // let p = AA of + // () => a = 10; tmpRes = a; t = tmpRes; return t + 5; + // test(p) + f(p) + // ``` + // TODO: it will become unnecessary once we have proper binding declarations in the Block IR + // and all uses of never-assigned variables will be known to be dead code + dflt.fold(false)(_.willBeNonEndTailBlock) && arms.forall { case (_, body) => body.willBeNonEndTailBlock } + => + super.applyBlock(Match(scrut, arms, dflt, End(""))) + case _ => super.applyBlock(b) + + override def applyResult(r: Result): Result = r match + case _: Call => + // calls to fusing contructors are handled in `applyResult2` + // here we only handle calls to non-fusing constructors and functions + assert(!d.filteredCtorDests.isDefinedAt(r.uid)) + super.applyResult(r) + case _ => super.applyResult(r) + + override def applyResult2(r: Result)(k: Result => Block): Block = + def handleCallLike(f: Path, args: Ls[Path], uid: ResultId) = + val c = f match + case s: Select => s.symbol.get.asCls.get + case Value.Ref(l) => l.asCls.get + case _ => ??? + d.filteredCtorDests.get(uid).get match + case CtorFinalDest.Match(scrut, expr, sels, selsMap) => + // use pre-determined symbols, create temp symbols for un-used fields + val usedFieldIdentToSymbolsToBeReplaced = selsMap._1 + val allFieldIdentToSymbolsToBeReplaced = d.getClsFields(c).map: f => + f.id -> usedFieldIdentToSymbolsToBeReplaced.getOrElse(f.id, TempSymbol(N, s"_deforest_${c.name}_${f.id.name}_unused")) + + // if all vars are temp vars, no need to create more temp vars + // otherwise, create temps for var symbols (which will be function params with these temp vars flowing in) + val assignedTempSyms = + if allFieldIdentToSymbolsToBeReplaced.forall(_._2.isInstanceOf[TempSymbol]) then + allFieldIdentToSymbolsToBeReplaced.map(a => a._1 -> a._2.asInstanceOf[TempSymbol]) + else + allFieldIdentToSymbolsToBeReplaced.map { case (id, s) => s match + case ts: TempSymbol => id -> ts + case vs: VarSymbol => id -> TempSymbol(N, s"${vs.name}_tmp") + } + + val bodyAndRestInLam = matchArms.getOrElseUpdate( + scrut, + expr, + c, + sels.toSet, + assignedTempSyms.filter(a => usedFieldIdentToSymbolsToBeReplaced.contains(a._1)).map(a => a._1 -> Value.Ref(a._2).asInstanceOf[Value.Ref]).toMap, + selsMap._1 -> selsMap._2) + + args.zip(assignedTempSyms.map(_._2)).foldRight[Block](k(bodyAndRestInLam)): + case ((a, tmp), rest) => applyResult2(a) { r => Assign(tmp, r, rest) } + + case CtorFinalDest.Sel(s) => + val selFieldName = s.getResult match { case Select(p, nme) => nme } + val idx = d.getClsFields(c).indexWhere(s => s.id === selFieldName) + k(args(idx)) + + r match + case call@Call(f, args) if d.filteredCtorDests.isDefinedAt(call.uid) => + handleCallLike(f, args.map { case Arg(false, value) => value }, call.uid) + case ins@Instantiate(cls, args) if d.filteredCtorDests.isDefinedAt(ins.uid) => + handleCallLike(cls, args, ins.uid) + case _ => super.applyResult2(r)(k) + + def handleObjFusing(objCallExprUid: ResultId, objClsSym: ModuleSymbol) = + // must be a pat mat on objects; no support for selection on objects yet + val CtorFinalDest.Match(scrut, expr, sels, selsMap) = d.filteredCtorDests(objCallExprUid): @unchecked + val body = expr.arms.find{ case (Case.Cls(m, _) -> body) => m === objClsSym }.map(_._2).orElse(expr.dflt).get + matchArms.getOrElseUpdate(scrut, expr, objClsSym, Set.empty, Map.empty) + + override def applyPath(p: Path): Path = p match + // a selection which is a consumer on its own + case s@Select(p, nme) if d.rewritingSelConsumers.contains(s.uid) => applyPath(p) + + // a selection inside a fusing match that needs to be replaced by pre-computed symbols + case s@Select(p, nme) if replaceSelInfo.get(s.uid).isDefined => Value.Ref(replaceSelInfo(s.uid)) + + case s@Select(p, nme) => s.symbol.flatMap(_.asObj) match + // a fusing object constructor + case Some(obj) if d.filteredCtorDests.isDefinedAt(s.uid) => handleObjFusing(s.uid, obj) + case _ => super.applyPath(s) + + case v: Value => applyValue(v) + case _ => super.applyPath(p) + + override def applyValue(v: Value): Value = v match + case r@Value.Ref(l) => l.asObj match + case None => r + case Some(obj) if d.filteredCtorDests.isDefinedAt(r.uid) => handleObjFusing(r.uid, obj) + case _ => super.applyValue(v) + case _ => super.applyValue(v) diff --git a/hkmc2/shared/src/test/mlscript/backlog/ToTriage.mls b/hkmc2/shared/src/test/mlscript/backlog/ToTriage.mls index 2a149abc4..d66749822 100644 --- a/hkmc2/shared/src/test/mlscript/backlog/ToTriage.mls +++ b/hkmc2/shared/src/test/mlscript/backlog/ToTriage.mls @@ -390,6 +390,13 @@ Foo(1) // ——— ——— ——— +if x > 0 + then A + then B +//│ /!!!\ Uncaught error: scala.MatchError: InfixApp(App(Ident(>),Tup(List(Ident(x), IntLit(0)))),keyword 'then',Ident(A)) (of class hkmc2.syntax.Tree$InfixApp) + +// ——— ——— ——— + data class Foo(...args) Foo(1, 2, 3).args @@ -398,31 +405,31 @@ Foo(1, 2, 3).args :todo if Foo(1, 2, 3) is Foo(...args) then args //│ ╔══[ERROR] the constructor does not take any arguments but found 1 -//│ ║ l.399: if Foo(1, 2, 3) is Foo(...args) then args +//│ ║ l.406: if Foo(1, 2, 3) is Foo(...args) then args //│ ╙── ^^^^^^^^^^^^ //│ ╔══[ERROR] Name not found: args -//│ ║ l.399: if Foo(1, 2, 3) is Foo(...args) then args +//│ ║ l.406: if Foo(1, 2, 3) is Foo(...args) then args //│ ╙── ^^^^ //│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing if Foo(1, 2, 3) is Foo(a, b, c) then [a, b, c] //│ ╔══[ERROR] the constructor does not take any arguments but found 3 -//│ ║ l.408: if Foo(1, 2, 3) is Foo(a, b, c) then [a, b, c] +//│ ║ l.415: if Foo(1, 2, 3) is Foo(a, b, c) then [a, b, c] //│ ╙── ^^^^^^^^^^^^ //│ ╔══[ERROR] Name not found: a -//│ ║ l.408: if Foo(1, 2, 3) is Foo(a, b, c) then [a, b, c] +//│ ║ l.415: if Foo(1, 2, 3) is Foo(a, b, c) then [a, b, c] //│ ╙── ^ //│ ╔══[ERROR] Name not found: b -//│ ║ l.408: if Foo(1, 2, 3) is Foo(a, b, c) then [a, b, c] +//│ ║ l.415: if Foo(1, 2, 3) is Foo(a, b, c) then [a, b, c] //│ ╙── ^ //│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing if Foo(1, 2, 3) is Foo(arg) then arg //│ ╔══[ERROR] the constructor does not take any arguments but found 1 -//│ ║ l.420: if Foo(1, 2, 3) is Foo(arg) then arg +//│ ║ l.427: if Foo(1, 2, 3) is Foo(arg) then arg //│ ╙── ^^^^^^^^ //│ ╔══[ERROR] Name not found: arg -//│ ║ l.420: if Foo(1, 2, 3) is Foo(arg) then arg +//│ ║ l.427: if Foo(1, 2, 3) is Foo(arg) then arg //│ ╙── ^^^ //│ /!!!\ Uncaught error: scala.NotImplementedError: an implementation is missing diff --git a/hkmc2/shared/src/test/mlscript/deforest/append.mls b/hkmc2/shared/src/test/mlscript/deforest/append.mls new file mode 100644 index 000000000..7f1effc58 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/deforest/append.mls @@ -0,0 +1,74 @@ +:js +:deforest + +//│ No fusion opportunity + + +object Nil +data class (::) Cons(h, t) + + + +fun append1(xs1, ys) = if xs1 is + h :: t then h :: append1(t, ys) + Nil then ys +fun append2(xs2, ys) = if xs2 is + h :: t then h :: append2(t, ys) + Nil then ys +fun appendThree(xs, ys, zs) = + append1(append2(xs, ys), zs) +appendThree of + id(1 :: 2 :: Nil) + id(3 :: 4 :: Nil) + id(5 :: 6 :: Nil) +//│ = Cons(1, Cons(2, Cons(3, Cons(4, Cons(5, Cons(6, Nil)))))) +//│ No fusion opportunity + +// maybe the fusion target for the previous program +fun appendReified(ys, zs) = if ys is + h :: t then h :: appendReified(t, zs) + Nil then zs +fun append1(ys, zs) = ys(zs) +fun append2(xs, ys) = + if xs is + h :: t then + zs => h :: append1(append2(t, ys), zs) // normal fusion + Nil then zs => appendReified(ys, zs) // reified +fun test(xs, ys, zs) = append1(append2(xs, ys), zs) +test of + id(1 :: 2 :: Nil) + id(3 :: 4 :: Nil) + id(5 :: 6 :: Nil) +//│ = Cons(1, Cons(2, Cons(3, Cons(4, Cons(5, Cons(6, Nil)))))) +//│ No fusion opportunity + + +fun idList(l) = if l is + h :: t then h :: idList(t) + Nil then Nil +fun append(xs, ys) = if xs is + h :: t then h :: append(t, ys) + Nil then idList(ys) +fun appendThree(xs, ys, zs) = + append(append(xs, ys), zs) +appendThree of + id(1 :: 2 :: Nil) + id(3 :: 4 :: Nil) + id(5 :: 6 :: Nil) +//│ = Cons(1, Cons(2, Cons(3, Cons(4, Cons(5, Cons(6, Nil)))))) +//│ No fusion opportunity + + + +fun append(xs, ys) = if xs is + h :: t then h :: append(t, ys) + Nil then idList(ys) +fun concat(lss) = if lss is + hh :: tt then append(hh, concat(tt)) + Nil then Nil +concat of id of + (1 :: 2 :: Nil) :: + (3 :: 4 :: Nil) :: + (5 :: 6 :: Nil) :: Nil +//│ = Cons(1, Cons(2, Cons(3, Cons(4, Cons(5, Cons(6, Nil)))))) +//│ No fusion opportunity diff --git a/hkmc2/shared/src/test/mlscript/deforest/determinism.mls b/hkmc2/shared/src/test/mlscript/deforest/determinism.mls new file mode 100644 index 000000000..7c9d6853f --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/deforest/determinism.mls @@ -0,0 +1,164 @@ +:js + +data class A(a) +data class B(a, b, c, d, e) + +:deforest +:sjs +if B(1,2,3,4,5) is + B(a,b,c,d,e) then 0 +//│ JS (unsanitized): +//│ let scrut, param0, param1, param2, param3, param4, a, b, c, d, e; +//│ scrut = B1(1, 2, 3, 4, 5); +//│ if (scrut instanceof B1.class) { +//│ param0 = scrut.a; +//│ param1 = scrut.b; +//│ param2 = scrut.c; +//│ param3 = scrut.d; +//│ param4 = scrut.e; +//│ a = param0; +//│ b = param1; +//│ c = param2; +//│ d = param3; +//│ e = param4; +//│ 0 +//│ } else { +//│ throw new this.Error("match error"); +//│ } +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ ==== JS (deforested): ==== +//│ let scrut, _deforest_B_e, _deforest_B_d, _deforest_B_c, _deforest_B_b, _deforest_B_a; +//│ _deforest_B_a = 1; +//│ _deforest_B_b = 2; +//│ _deforest_B_c = 3; +//│ _deforest_B_d = 4; +//│ _deforest_B_e = 5; +//│ scrut = () => { +//│ let param0, param1, param2, param3, param4, a, b, c, d, e; +//│ param0 = _deforest_B_a; +//│ param1 = _deforest_B_b; +//│ param2 = _deforest_B_c; +//│ param3 = _deforest_B_d; +//│ param4 = _deforest_B_e; +//│ a = param0; +//│ b = param1; +//│ c = param2; +//│ d = param3; +//│ e = param4; +//│ return 0 +//│ }; +//│ runtime.safeCall(scrut()) +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< +//│ = 0 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 0 +//│ 1 fusion opportunities: +//│ B --match--> `if scrut is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +:sjs +:deforest +fun c2(x) = if x is + A(A(A(A(A(A(a)))))) then a +c2 of A(A(A(A(A(A(1)))))) +//│ JS (unsanitized): +//│ let c2, tmp, tmp1, tmp2, tmp3, tmp4, tmp5; +//│ c2 = function c2(x) { +//│ let param01, param02, param03, param04, param05, param06, a1; +//│ if (x instanceof A1.class) { +//│ param01 = x.a; +//│ if (param01 instanceof A1.class) { +//│ param02 = param01.a; +//│ if (param02 instanceof A1.class) { +//│ param03 = param02.a; +//│ if (param03 instanceof A1.class) { +//│ param04 = param03.a; +//│ if (param04 instanceof A1.class) { +//│ param05 = param04.a; +//│ if (param05 instanceof A1.class) { +//│ param06 = param05.a; +//│ a1 = param06; +//│ return a1 +//│ } else { +//│ throw new globalThis.Error("match error"); +//│ } +//│ } else { +//│ throw new globalThis.Error("match error"); +//│ } +//│ } else { +//│ throw new globalThis.Error("match error"); +//│ } +//│ } else { +//│ throw new globalThis.Error("match error"); +//│ } +//│ } else { +//│ throw new globalThis.Error("match error"); +//│ } +//│ } else { +//│ throw new globalThis.Error("match error"); +//│ } +//│ }; +//│ tmp = A1(1); +//│ tmp1 = A1(tmp); +//│ tmp2 = A1(tmp1); +//│ tmp3 = A1(tmp2); +//│ tmp4 = A1(tmp3); +//│ tmp5 = A1(tmp4); +//│ c2(tmp5) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ ==== JS (deforested): ==== +//│ let c2, tmp, tmp1, tmp2, tmp3, tmp4, tmp5, _deforest_A_a, _deforest_A_a1, _deforest_A_a2, _deforest_A_a3, _deforest_A_a4, _deforest_A_a5; +//│ c2 = function c2(x) { +//│ return runtime.safeCall(x()) +//│ }; +//│ _deforest_A_a5 = 1; +//│ tmp = () => { +//│ let param01, a1; +//│ param01 = _deforest_A_a5; +//│ a1 = param01; +//│ return a1 +//│ }; +//│ _deforest_A_a4 = tmp; +//│ tmp1 = () => { +//│ let param01; +//│ param01 = _deforest_A_a4; +//│ return runtime.safeCall(param01()) +//│ }; +//│ _deforest_A_a3 = tmp1; +//│ tmp2 = () => { +//│ let param01; +//│ param01 = _deforest_A_a3; +//│ return runtime.safeCall(param01()) +//│ }; +//│ _deforest_A_a2 = tmp2; +//│ tmp3 = () => { +//│ let param01; +//│ param01 = _deforest_A_a2; +//│ return runtime.safeCall(param01()) +//│ }; +//│ _deforest_A_a1 = tmp3; +//│ tmp4 = () => { +//│ let param01; +//│ param01 = _deforest_A_a1; +//│ return runtime.safeCall(param01()) +//│ }; +//│ _deforest_A_a = tmp4; +//│ tmp5 = () => { +//│ let param01; +//│ param01 = _deforest_A_a; +//│ return runtime.safeCall(param01()) +//│ }; +//│ c2(tmp5) +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< +//│ = 1 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 1 +//│ 6 fusion opportunities: +//│ A --match--> `if param0 is ...` +//│ A --match--> `if param0 is ...` +//│ A --match--> `if param0 is ...` +//│ A --match--> `if param0 is ...` +//│ A --match--> `if param0 is ...` +//│ A --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< diff --git a/hkmc2/shared/src/test/mlscript/deforest/imperative.mls b/hkmc2/shared/src/test/mlscript/deforest/imperative.mls new file mode 100644 index 000000000..45894e3f2 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/deforest/imperative.mls @@ -0,0 +1,50 @@ +:js +:deforest + +//│ No fusion opportunity + +object A +object B +// object C +// class AA(aa) +// class BB(bb) +// object None +// class Some(value) + +// * Not fused: two `x is A` consumers +let x = if true then A else B +fun foo(x) = + if x is A do print(123) + if x is + A then 1 + B then 2 +//│ x = A +//│ No fusion opportunity + +// * We could make it work. But it's a special case that's probably not very important +let x = if true then A else B +fun foo(x) = + if x is A do print(123) + if x is B do print(456) +//│ x = A +//│ No fusion opportunity + +fun foo(k, x) = + if x === 0 do k(A) + k(if x > 0 + then A + else B) +fun bar(v) = if v is + A then 1 + B then 2 +foo(bar, 123) +//│ = 1 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 1 +//│ 3 fusion opportunities: +//│ A --match--> `if v is ...` +//│ A --match--> `if v is ...` +//│ B --match--> `if v is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + diff --git a/hkmc2/shared/src/test/mlscript/deforest/listComprehension.mls b/hkmc2/shared/src/test/mlscript/deforest/listComprehension.mls new file mode 100644 index 000000000..e22cabd87 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/deforest/listComprehension.mls @@ -0,0 +1,96 @@ +:js +:deforest + +//│ No fusion opportunity +object Nil +data class (::) Cons(h, t) +data class Pair(a, b) +object A +object B + +fun zip(xs_zip, ys_zip) = if + xs_zip is x :: xt and ys_zip is y :: yt then Pair(x, y) :: zip(xt, yt) + else Nil +fun enumFromTo(a, b) = if a < (b + 1) then a :: enumFromTo(a + 1, b) else Nil +fun test() = + fun lscomp1(ls) = if ls is + Pair(x, y1) :: t then + fun lscomp2(ls2) = if ls2 is + Pair(y2, z) :: t2 and + y1 == y2 then Pair(x, z) :: lscomp2(t2) + else lscomp2(t2) + else lscomp1(t) + lscomp2(zip(enumFromTo(x, x + 2), enumFromTo(y1, y1 + 1))) + else Nil + lscomp1(zip(enumFromTo(1, 3), enumFromTo(2, 4))) +test() +//│ = Cons(Pair(1, 3), Cons(Pair(2, 4), Cons(Pair(3, 5), Nil))) +//│ No fusion opportunity + + +fun append(xs, ys) = if xs is + h :: t then h :: append(t, ys) + Nil then ys +fun concatMap(f, ls) = if ls is + h :: t then append(f(h), concatMap(f, t)) + Nil then Nil +fun test() = + fun f1(a1) = if a1 is + Pair(a, b) then + fun f2(a2) = if a2 is Pair(c, d) then Pair(a, c) :: Nil else Nil + concatMap(f2, Pair(1, 3) :: Pair(2, 3) :: Pair(a, b) :: Nil) + else Nil + concatMap(f1, Pair(5, 10) :: Nil) +test() +//│ = Cons(Pair(5, 1), Cons(Pair(5, 2), Cons(Pair(5, 5), Nil))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = Cons(Pair(5, 1), Cons(Pair(5, 2), Cons(Pair(5, 5), Nil))) +//│ 6 fusion opportunities: +//│ Cons --match--> `if ls is ...` +//│ Cons --match--> `if ls is ...` +//│ Cons --match--> `if ls is ...` +//│ Cons --match--> `if ls is ...` +//│ Nil --match--> `if ls is ...` +//│ Nil --match--> `if ls is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun lscomp1(ls1, ls2) = if ls1 is + h1 :: t1 then lscomp2(ls2, h1, t1) + Nil then Nil +fun lscomp2(ls2, h1, t1) = if ls2 is + h2 :: t2 then Pair(h1, h2) :: lscomp2(t2, h1, t1) + Nil then lscomp1(t1, ls2) +lscomp1(1 :: 2 :: Nil, 3 :: 4 :: Nil) +//│ = Cons(Pair(1, 3), Cons(Pair(1, 4), Nil)) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = Cons(Pair(1, 3), Cons(Pair(1, 4), Nil)) +//│ 6 fusion opportunities: +//│ Cons --match--> `if ls1 is ...` +//│ Cons --match--> `if ls1 is ...` +//│ Cons --match--> `if ls2 is ...` +//│ Cons --match--> `if ls2 is ...` +//│ Nil --match--> `if ls1 is ...` +//│ Nil --match--> `if ls2 is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun test(k, ls) = + fun lscomp(ls) = if ls is + h :: t then (h + 1) :: lscomp(t) + Nil then Nil + if k is + A then lscomp(ls) + B then Nil +test(A, 1 :: 2 :: 3 :: Nil) +//│ = Cons(2, Cons(3, Cons(4, Nil))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = Cons(2, Cons(3, Cons(4, Nil))) +//│ 5 fusion opportunities: +//│ A --match--> `if k is ...` +//│ Cons --match--> `if ls is ...` +//│ Cons --match--> `if ls is ...` +//│ Cons --match--> `if ls is ...` +//│ Nil --match--> `if ls is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< diff --git a/hkmc2/shared/src/test/mlscript/deforest/nestedMatch.mls b/hkmc2/shared/src/test/mlscript/deforest/nestedMatch.mls new file mode 100644 index 000000000..4d320f26b --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/deforest/nestedMatch.mls @@ -0,0 +1,402 @@ +:js +:deforest + +//│ No fusion opportunity +object A +object B +object C +data class AA(x) +data class BB(x) +data class CC(x) + +object Nil +data class Cons(h, t) + +object None +data class Some(x) + + +fun f(x, y) = + let t = if x is + AA(a) then + let m = if a is + AA(x) then x + m + 9 + BB(b) then b + t + y +fun g(x) = if x is + AA then 0 +let a = AA(AA(3)) +f(a, 2) + g(a) + f(BB(3), 2) + f(AA(AA(4)), 5) +//│ = 37 +//│ a = AA(AA(3)) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 37 +//│ 2 fusion opportunities: +//│ AA --match--> `if a is ...` +//│ AA --match--> `if a is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun f(x, y) = + let n = if x is + AA(a) then + if a is + AA(m) then m + n + 2 + y +f(AA(AA(3)), 9) + f(AA(AA(4)), 10) +//│ = 30 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 30 +//│ 4 fusion opportunities: +//│ AA --match--> `if a is ...` +//│ AA --match--> `if a is ...` +//│ AA --match--> `if x is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun f(x) = + let n = if x is + AA(a) then + if a is + AA(m) then m + n + 2 +f(AA(AA(3))) +//│ = 5 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 5 +//│ 2 fusion opportunities: +//│ AA --match--> `if a is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun f(x) = + let t = if x is + AA(AA(a)) then a + t + 3 +f(AA(AA(A))) + f(AA(AA(A))) +//│ = "A3A3" +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = "A3A3" +//│ 4 fusion opportunities: +//│ AA --match--> `if param0 is ...` +//│ AA --match--> `if param0 is ...` +//│ AA --match--> `if x is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun c2(x) = if x is + AA(AA(a)) then a +c2(AA(AA(0))) +//│ = 0 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 0 +//│ 2 fusion opportunities: +//│ AA --match--> `if param0 is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun f(a) = if a is + AA(BB(B)) then 3 +f(AA(BB(B))) +//│ = 3 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 3 +//│ 3 fusion opportunities: +//│ AA --match--> `if a is ...` +//│ B --match--> `if param0 is ...` +//│ BB --match--> `if param0 is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun test(x, y, z, i) = if x is + AA(a) then + let m = if y is + AA(a1) then a1 + i + let n = if z is + BB(a2) then a2 - i + m + n +test(AA(1), AA(2), BB(3), 4) + test(AA(1), AA(2), BB(3), 4) +//│ = 10 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 10 +//│ 6 fusion opportunities: +//│ AA --match--> `if x is ...` +//│ AA --match--> `if x is ...` +//│ AA --match--> `if y is ...` +//│ AA --match--> `if y is ...` +//│ BB --match--> `if z is ...` +//│ BB --match--> `if z is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +// only the match in the middle is fused +fun f1(x) = if x is + AA(x1) then + if x1 is + BB(x2) then + if x2 is + CC(x3) then x3 +fun f2(x) = if x is AA then "f2" +fun f3(x) = if x is CC then "f3" + x.x +fun aa(x) = AA(x) +let cc = CC("cc") +f1(aa(BB(cc))) + f2(aa(BB(CC(0)))) + f3(cc) +//│ = "ccf2f3cc" +//│ cc = CC("cc") +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = "ccf2f3cc" +//│ 2 fusion opportunities: +//│ BB --match--> `if x1 is ...` +//│ BB --match--> `if x1 is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun c(x, y) = if x is + AA then + let t = if y is + A then 2 + t + x.x +let y = A +c(AA(2), y) +//│ = 4 +//│ y = A +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 4 +//│ 2 fusion opportunities: +//│ A --match--> `if y is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun c(x, y) = if x is + AA then + let t = if y is + A then 2 + x.x + t + x.x +fun c2(y) = if y is A then 3 +let y = A +c(AA(2), y) + c2(y) +//│ = 9 +//│ y = A +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 9 +//│ 1 fusion opportunities: +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +// need to include computations of `rest` from more than one levels of parent matches +fun test(x) = + let t = if x is + AA(AA(AA(a))) then a + AA then "3" + t + "5" +fun f(a) = if a is + AA(AA) then "0" +let p = AA(AA(AA("10"))) +test(p) + f(p) + test(AA(A)) +//│ = "105035" +//│ p = AA(AA(AA("10"))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = "105035" +//│ 1 fusion opportunities: +//│ AA --match--> `if param0 is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + + +fun f(x, y) = + let tmp = BB(y + 1) + if tmp is + BB then + let m = if x is + AA then + if AA(2) is + AA(yf) then yf + if AA(3) is + AA then m + tmp.x +let aa = AA(3) +f(aa, 3) + f(aa, 4) +//│ = 13 +//│ aa = AA(3) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 13 +//│ 4 fusion opportunities: +//│ AA --match--> `if scrut is ...` +//│ AA --match--> `if scrut is ...` +//│ AA --match--> `if x is ...` +//│ BB --match--> `if tmp is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun test(n, p, q) = + n + + if AA(0) is + AA then + let k = if p is + AA(x) then x + if q is + AA(y) then k + y +test(3, AA(2), AA(3)) + test(3, AA(2), AA(3)) +//│ = 16 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 16 +//│ 5 fusion opportunities: +//│ AA --match--> `if p is ...` +//│ AA --match--> `if p is ...` +//│ AA --match--> `if q is ...` +//│ AA --match--> `if q is ...` +//│ AA --match--> `if scrut is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun test(n, p, q, a) = + let o + if a is + A then + let k + if p is + AA(x) then + k = x + if q is + AA(y) then + o = k + y + o + n +let a = A +fun c(a) = if a is A then 0 +test(1, AA(2), AA(3), a) + test(2, AA(2), AA(3), a) + c(a) +//│ = 13 +//│ a = A +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 13 +//│ 4 fusion opportunities: +//│ AA --match--> `if p is ...` +//│ AA --match--> `if p is ...` +//│ AA --match--> `if q is ...` +//│ AA --match--> `if q is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun test(x, y, z, i) = + let k = if x is + AA(a) then + let m = if y is + AA(a1) then a1 + i + let n = if z is + BB(a2) then a2 - i + m + n + k + i +test(AA(1), AA(2), BB(3), 4) + test(AA(1), AA(2), BB(3), 4) +//│ = 18 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 18 +//│ 6 fusion opportunities: +//│ AA --match--> `if x is ...` +//│ AA --match--> `if x is ...` +//│ AA --match--> `if y is ...` +//│ AA --match--> `if y is ...` +//│ BB --match--> `if z is ...` +//│ BB --match--> `if z is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun test(x, y, z, i) = + let k = if x is + AA(a) then + let m = if y is + AA(a1) then a1 + i + BB(b2) then b2 - i + let n = if z is + BB(a2) then a2 - i + m + n + k + i +test(AA(1), AA(2), BB(3), 4) + test(AA(1), BB(2), BB(3), 4) +//│ = 10 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 10 +//│ 6 fusion opportunities: +//│ AA --match--> `if x is ...` +//│ AA --match--> `if x is ...` +//│ AA --match--> `if y is ...` +//│ BB --match--> `if y is ...` +//│ BB --match--> `if z is ...` +//│ BB --match--> `if z is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun test(x) = + let n = if x is + AA(BB(b)) then b + AA(CC(c)) then c + else 0 + n + 3 +fun c(x) = if x is AA then 0 +fun p(x) = AA(x) +test(p(BB(3))) + test(p(CC(3))) + c(p(A)) +//│ = 12 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 12 +//│ 3 fusion opportunities: +//│ A --match--> `if param0 is ...` +//│ BB --match--> `if param0 is ...` +//│ CC --match--> `if param0 is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + + + +fun f(x, y, z, k) = + let tmp = if z then BB(y + 1) else BB(y - 1) + tmp.x + + if tmp is + BB then + let m = if x is + AA then + if AA(2) is + AA(yf) then yf + if AA(3) is + AA then m + k +f(AA(3), 3, true, 10) + f(AA(5), 4, false, 20) +//│ = 41 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 41 +//│ 4 fusion opportunities: +//│ AA --match--> `if scrut is ...` +//│ AA --match--> `if scrut is ...` +//│ AA --match--> `if x is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun test(x) = + let t = if x is + AA(AA(AA(a))) then a + t + 5 +fun f(a) = if a is + AA(AA) then 0 +let p = AA(AA(AA(10))) +test(p) + f(p) +//│ = 15 +//│ p = AA(AA(AA(10))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 15 +//│ 1 fusion opportunities: +//│ AA --match--> `if param0 is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< diff --git a/hkmc2/shared/src/test/mlscript/deforest/recursive.mls b/hkmc2/shared/src/test/mlscript/deforest/recursive.mls new file mode 100644 index 000000000..45434baa5 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/deforest/recursive.mls @@ -0,0 +1,299 @@ +:js +:deforest + +//│ No fusion opportunity + + +object Nil +data class (::) Cons(h, t) +object None +data class Some(v) +object A +object B +data class T(n, l, r) +object L + +fun mk(n) = if n < 0 then Nil else n :: mk(n - 1) +fun map(f, ls_map) = if ls_map is + h :: t then f(h) :: map(f, t) + Nil then Nil +fun map1(f, ls_map1) = if ls_map1 is + h :: t then f(h) :: map1(f, t) + Nil then Nil +fun incr(x) = x + 1 +fun double(x) = x * 2 +fun test(ls) = map1(incr, map(double, ls)) +test(id(mk(5))) +//│ = Cons(11, Cons(9, Cons(7, Cons(5, Cons(3, Cons(1, Nil)))))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = Cons(11, Cons(9, Cons(7, Cons(5, Cons(3, Cons(1, Nil)))))) +//│ 2 fusion opportunities: +//│ Cons --match--> `if ls_map1 is ...` +//│ Nil --match--> `if ls_map1 is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun enumFromTo(a, b) = if a < b then a :: enumFromTo(a + 1, b) else Nil +fun map(ls) = if ls is + Nil then Nil + h :: t then (h + 4) :: map(t) +map(enumFromTo(1, 4)) +//│ = Cons(5, Cons(6, Cons(7, Nil))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = Cons(5, Cons(6, Cons(7, Nil))) +//│ 2 fusion opportunities: +//│ Cons --match--> `if ls is ...` +//│ Nil --match--> `if ls is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + +fun enumFromTo(a, b) = if a < b then a :: enumFromTo(a + 1, b) else Nil +fun sum(ls) = if ls is + Nil then 0 + h :: t then h + sum(t) +sum(enumFromTo(1,10)) +//│ = 45 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 45 +//│ 2 fusion opportunities: +//│ Cons --match--> `if ls is ...` +//│ Nil --match--> `if ls is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + +fun enumFromTo(a, b) = if a < b then a :: enumFromTo(a + 1, b) else Nil +fun map(f, ls) = + if ls is + Nil then Nil + h :: t then f(h) :: map(f, t) +map(x => x + 4, enumFromTo(1, 4)) +//│ = Cons(5, Cons(6, Cons(7, Nil))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = Cons(5, Cons(6, Cons(7, Nil))) +//│ 2 fusion opportunities: +//│ Cons --match--> `if ls is ...` +//│ Nil --match--> `if ls is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + +fun enumFromTo(a, b) = if a < b then a :: enumFromTo(a + 1, b) else Nil +fun map(f, ls) = + (if ls is + Nil then f => Nil + h :: t then f => f(h) :: map(f, t) + )(f) +map(x => x + 4, enumFromTo(1, 4)) +//│ = Cons(5, Cons(6, Cons(7, Nil))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = Cons(5, Cons(6, Cons(7, Nil))) +//│ 2 fusion opportunities: +//│ Cons --match--> `if ls is ...` +//│ Nil --match--> `if ls is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun enumFromTo(a, b) = if a < b then a :: enumFromTo(a + 1, b) else Nil +fun sum(ls, a) = if ls is + Nil then a + h :: t then sum(t, h + a) +sum(enumFromTo(1, 10), 0) +//│ = 45 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 45 +//│ 2 fusion opportunities: +//│ Cons --match--> `if ls is ...` +//│ Nil --match--> `if ls is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun mk(n) = if n < 0 then Nil else n :: mk(n - 1) +fun incr(x) = x + 1 +fun map(f, xs_map) = if xs_map is + Nil then Nil + x :: xs then f(x) :: map(f, xs) +fun rev(xs_rev, acc) = if xs_rev is + x :: xs then rev(xs, x :: acc) + Nil then acc +fun test(xs) = map(incr, rev(xs, Nil)) +test(id(mk(5))) +//│ = Cons(1, Cons(2, Cons(3, Cons(4, Cons(5, Cons(6, Nil)))))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = Cons(1, Cons(2, Cons(3, Cons(4, Cons(5, Cons(6, Nil)))))) +//│ 2 fusion opportunities: +//│ Cons --match--> `if xs_map is ...` +//│ Nil --match--> `if xs_map is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun mk(n) = if n < 0 then Nil else n :: mk(n - 1) +fun incr(x) = x + 1 +fun map(f, xs_map) = if xs_map is + Nil then Nil + x :: xs then f(x) :: map(f, xs) +fun rev(xs_rev, acc) = if xs_rev is + x :: xs then rev(xs, x :: acc) + Nil then acc +fun test(xs) = rev(map(incr, xs), Nil) +test(id(mk(3))) +//│ = Cons(1, Cons(2, Cons(3, Cons(4, Nil)))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = Cons(1, Cons(2, Cons(3, Cons(4, Nil)))) +//│ 2 fusion opportunities: +//│ Cons --match--> `if xs_rev is ...` +//│ Nil --match--> `if xs_rev is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +data class Pair(a, b) + + +fun pair_up(xs) = + if xs is + x :: xss then + if xss is + y :: xs then Pair(x, y) :: pair_up(xs) + else Nil + else Nil +fun mk(n) = if n > 0 then (n - 1) :: n :: (n + 1) :: mk(n - 1) else Nil +fun test(x) = pair_up(mk(x)) +test(4) +//│ = Cons(Pair(3, 4), Cons(Pair(5, 2), Cons(Pair(3, 4), Cons(Pair(1, 2), Cons(Pair(3, 0), Cons(Pair(1, 2), Nil)))))) +//│ No fusion opportunity + + + +fun mk(n) = if n > 0 then n :: mk(n - 1) else Nil +fun mk2d(n) = if n > 0 then mk(n) :: mk2d(n - 1) else Nil +fun append(xs, ys) = if xs is + Nil then ys + x :: xs then x :: append(xs, ys) +fun sum(ls_sum) = if ls_sum is + Nil then 0 + x :: xs then x + sum(xs) +fun flatten(ls_flatten) = if ls_flatten is + Nil then Nil + x :: xs then append(x, flatten(xs)) +fun test(ls) = sum(flatten(ls)) +test(id(mk2d(4))) +//│ = 20 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 20 +//│ 2 fusion opportunities: +//│ Cons --match--> `if ls_sum is ...` +//│ Nil --match--> `if ls_sum is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun count(c, xs) = if xs is + h :: t then count(c + 1, t) + Nil then c +fun rev(a, ys) = if ys is + h :: t then rev(h :: a, t) + Nil then a +count(0, rev(Nil, 1 :: 2 :: Nil)) +//│ = 2 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 2 +//│ 5 fusion opportunities: +//│ Cons --match--> `if xs is ...` +//│ Cons --match--> `if ys is ...` +//│ Cons --match--> `if ys is ...` +//│ Nil --match--> `if xs is ...` +//│ Nil --match--> `if ys is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun last(ls) = if ls is + h :: t and t is + Nil then Some(h) + _ :: _ then last(t) + Nil then None +let p = 1 :: 2 :: Nil +last(p) +//│ = Some(2) +//│ p = Cons(1, Cons(2, Nil)) +//│ No fusion opportunity + + + +fun map(f, xs_map) = if xs_map is + Nil then Nil + x :: xs then f(x) :: map(f, xs) +map(x => if x is A then 1 else 0, A :: B :: Nil) +//│ = Cons(1, Cons(0, Nil)) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = Cons(1, Cons(0, Nil)) +//│ 5 fusion opportunities: +//│ A --match--> `if x is ...` +//│ B --match--> `if x is ...` +//│ Cons --match--> `if xs_map is ...` +//│ Cons --match--> `if xs_map is ...` +//│ Nil --match--> `if xs_map is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun c(x) = if x is + T(n, l, r) then T of + if n is + A then 0 + B then 1 + c(l) + c(r) + L then L +c(T(A, T(B, L, L), T(A, L, L))) +//│ = T(0, T(1, L, L), T(0, L, L)) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = T(0, T(1, L, L), T(0, L, L)) +//│ 10 fusion opportunities: +//│ A --match--> `if n is ...` +//│ A --match--> `if n is ...` +//│ B --match--> `if n is ...` +//│ L --match--> `if x is ...` +//│ L --match--> `if x is ...` +//│ L --match--> `if x is ...` +//│ L --match--> `if x is ...` +//│ T --match--> `if x is ...` +//│ T --match--> `if x is ...` +//│ T --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun max(ms, m) = if ms is + h :: t and + h > m then max(t, h) + else max(t, m) + Nil then m +max(3 :: 2 :: 4 :: 1 :: 0 :: Nil, 0) +//│ = 4 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 4 +//│ 6 fusion opportunities: +//│ Cons --match--> `if ms is ...` +//│ Cons --match--> `if ms is ...` +//│ Cons --match--> `if ms is ...` +//│ Cons --match--> `if ms is ...` +//│ Cons --match--> `if ms is ...` +//│ Nil --match--> `if ms is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun filter(ls, f) = if ls is + h :: t and + f(h) then h :: filter(t, f) + else filter(t, f) + Nil then Nil +fun last(ls_last) = + fun go(a, ls_go) = if ls_go is + Nil then a + h :: t then go(h, t) + if ls_last is + h :: t then Some(go(h, t)) + Nil then None +fun lastFilter(ls, f) = last(filter(ls, f)) +lastFilter(id(1 :: 2 :: 3 :: 4 :: 5 :: Nil), x => (x % 2 == 0)) +//│ = Some(4) +//│ No fusion opportunity + + diff --git a/hkmc2/shared/src/test/mlscript/deforest/selectionsInNestedMatch.mls b/hkmc2/shared/src/test/mlscript/deforest/selectionsInNestedMatch.mls new file mode 100644 index 000000000..0c379fab0 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/deforest/selectionsInNestedMatch.mls @@ -0,0 +1,140 @@ +:js +:deforest + +//│ No fusion opportunity +object A +object B +object C +data class AA(x) +data class BB(x) + +object Nil +data class Cons(h, t) + +object None +data class Some(x) + + + + +fun c(x) = if x is + AA then + if A is + A then x.x +c(AA(2)) +//│ = 2 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 2 +//│ 2 fusion opportunities: +//│ A --match--> `if scrut is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun c(x, y) = if x is + AA then + if y is + A then x.x +c(AA(2), A) +//│ = 2 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 2 +//│ 2 fusion opportunities: +//│ A --match--> `if y is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun c(x, y) = if x is + AA then + if y is + A then x.x +fun p() = AA(2) +c(p(), A) +//│ = 2 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 2 +//│ 2 fusion opportunities: +//│ A --match--> `if y is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + + +fun c(x, y) = if x is + AA then + let a = if y is + A then 1 + a + x.x +fun p() = AA(2) +c(p(), A) +//│ = 3 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 3 +//│ 2 fusion opportunities: +//│ A --match--> `if y is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + + +fun f(x) = if x is + AA(AA(a)) then g(a) +fun g(x) = if x is + AA(b) then f(b) + A then 42 +f(AA(AA(AA(AA(AA(A)))))) +//│ = 42 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 42 +//│ 6 fusion opportunities: +//│ A --match--> `if x is ...` +//│ AA --match--> `if param0 is ...` +//│ AA --match--> `if param0 is ...` +//│ AA --match--> `if x is ...` +//│ AA --match--> `if x is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + +fun f(x) = if x is + AA(AA(a)) then a +f(AA(AA(A))) +//│ = A +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = A +//│ 2 fusion opportunities: +//│ AA --match--> `if param0 is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + +fun f(x) = if x is + AA(AA(a)) then a +fun p() = AA(AA(A)) +f(p()) +//│ = A +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = A +//│ 2 fusion opportunities: +//│ AA --match--> `if param0 is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun c(x, y) = + let t = if x is + AA then + if A is + A then x.x + t + y +c(AA(2), 10) +//│ = 12 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 12 +//│ 2 fusion opportunities: +//│ A --match--> `if scrut is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< diff --git a/hkmc2/shared/src/test/mlscript/deforest/simple.mls b/hkmc2/shared/src/test/mlscript/deforest/simple.mls new file mode 100644 index 000000000..abd673d70 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/deforest/simple.mls @@ -0,0 +1,584 @@ +:js +:deforest + +//│ No fusion opportunity + +object A +object B +object C +data class AA(aa) +data class BB(bb) +data class AAA(x, y) +data class BBB(x, y) +data class CCC(c) +object None +data class Some(value) + +fun test() = + let x = if true then A else B + if x is + A then 1 + B then 2 +test() +//│ = 1 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 1 +//│ 2 fusion opportunities: +//│ A --match--> `if x is ...` +//│ B --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + +fun test() = + let x = if true then AA(A) else BB(B) + if x is + AA(x) then x + BB(x) then x +test() +//│ = A +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = A +//│ 2 fusion opportunities: +//│ AA --match--> `if x is ...` +//│ BB --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun f(a) = if a is + A then 1 + B then 2 +fun test() = + let x = if true then AA(A) else BB(B) + if x is + AA(x) then f(x) + BB(x) then f(x) +test() +//│ = 1 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 1 +//│ 4 fusion opportunities: +//│ A --match--> `if a is ...` +//│ AA --match--> `if x is ...` +//│ B --match--> `if a is ...` +//│ BB --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +// `x.x` is successfully fused +fun f1(a) = if a is + A then 1 + B then 2 + C then 3 +fun f2(a) = if a is + A then 4 + B then 5 + C then 6 +fun test() = + let x = if true then AA(A) else BB(B) + if x is + AA then f1(x.aa) + BB then f2(x.bb) +test() +//│ = 1 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 1 +//│ 4 fusion opportunities: +//│ A --match--> `if a is ...` +//│ AA --match--> `if x is ...` +//│ B --match--> `if a is ...` +//│ BB --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun test() = + fun g(x) = if true then AA(11) else BB(22) + fun c(x) = if x is + AA(x) then x + BB(x) then x + c(g(true)) +test() +//│ = 11 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 11 +//│ 2 fusion opportunities: +//│ AA --match--> `if x is ...` +//│ BB --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + + +// multiple match, no fusion +fun test() = + let x = B + if x is + A then 1 + B then 3 + if x is + B then 2 +test() +//│ = 2 +//│ No fusion opportunity + + + +fun test() = + let x = A + let y = B + if x is + A then 1 + if y is + B then 2 +test() +//│ = 2 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 2 +//│ 2 fusion opportunities: +//│ A --match--> `if x is ...` +//│ B --match--> `if y is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun c(a) = + let x = if a is + A then 1 + B then 2 + print(x) + x +c(A) + c(B) +//│ > 1 +//│ > 2 +//│ = 3 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ > 1 +//│ > 2 +//│ = 3 +//│ 2 fusion opportunities: +//│ A --match--> `if a is ...` +//│ B --match--> `if a is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +// simple free var example + + + + + +:expect 5 +fun c(x) = + let t = x.aa + let n = if x is + AA then 2 + n + t +c(AA(3)) +//│ = 5 +//│ No fusion opportunity + + + +fun f(a, b) = if a is + A then if b is + B then 3 +f(A, B) +//│ = 3 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 3 +//│ 2 fusion opportunities: +//│ A --match--> `if a is ...` +//│ B --match--> `if b is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun f(x) = if x is + Some then if x.value > 1 then f(Some(x.value - 1)) else 0 +f(Some(2)) +//│ = 0 +//│ No fusion opportunity + + + +let x = A +let y = B +if x is + A then 1 +if y is + B then 2 +//│ = 2 +//│ x = A +//│ y = B +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 2 +//│ 2 fusion opportunities: +//│ A --match--> `if x is ...` +//│ B --match--> `if y is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun test() = + let x = A + let y = B + if x is + A then 1 + if y is + B then 2 +test() +//│ = 2 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 2 +//│ 2 fusion opportunities: +//│ A --match--> `if x is ...` +//│ B --match--> `if y is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun f(x, y) = + let a = if x is + AAA(n, m) then y + n - m + BBB(n, m) then m + 1 - n + a + 3 +f(AAA(1, 3), 1) + f(BBB(2, 3), 2) + f(AAA(3, 2), 4) + f(BBB(4, 6), 0) +//│ = 21 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 21 +//│ 4 fusion opportunities: +//│ AAA --match--> `if x is ...` +//│ AAA --match--> `if x is ...` +//│ BBB --match--> `if x is ...` +//│ BBB --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun c1(x) = if x is + AA then + if A is + A then x.aa +fun c2(x) = if x is + AA then x.aa +fun p(a) = c1(a) + c2(a) +p(AA(1)) +//│ = 2 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 2 +//│ 1 fusion opportunities: +//│ A --match--> `if scrut is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun c(x, y) = if x is + AA(a) then + if y is + A then a +c(AA(2), A) +//│ = 2 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 2 +//│ 2 fusion opportunities: +//│ A --match--> `if y is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun f(x, y) = + let a = if x is + AAA(n, m) then y + n - m + CCC(n) then n + 1 + a + 3 +f(AAA(1, 3), 1) + f(CCC(2), 2) + f(AAA(3, 2), 4) + f(CCC(4), 0) +//│ = 24 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 24 +//│ 4 fusion opportunities: +//│ AAA --match--> `if x is ...` +//│ AAA --match--> `if x is ...` +//│ CCC --match--> `if x is ...` +//│ CCC --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + + +fun f(x, y) = + let a = if x is + AAA then y + x.y + CCC(n) then n + 1 + a + 3 +f(AAA(1, 3), 1) + f(CCC(2), 2) + f(AAA(3, 2), 4) +//│ = 22 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 22 +//│ 3 fusion opportunities: +//│ AAA --match--> `if x is ...` +//│ AAA --match--> `if x is ...` +//│ CCC --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun c(x, m) = if x is + AA(n) then n + 1 + else m +c(BB(3), 0) +//│ = 0 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 0 +//│ 1 fusion opportunities: +//│ BB --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + + +fun f(x) = if x is + AA(b) then g(b) +fun g(b) = if b is + AA(a) then a + 1 +f(AA(AA(0))) +//│ = 1 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 1 +//│ 2 fusion opportunities: +//│ AA --match--> `if b is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun f(x) = x.aa.aa +f(AA(AA(3))) +//│ = 3 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 3 +//│ 2 fusion opportunities: +//│ AA --sel--> `.Ident(aa)` +//│ AA --sel--> `.Ident(aa)` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + +fun inner(x, y) = + let t = if x is + AA(a) then a + y + BB(b) then b - y + t * y +fun outer1(x, y, z) = + let t = if x is + AA(a) then inner(a, y) + BB(b) then inner(b, y) + t + z +fun outer2(x, y, z) = + let t = if x is + AA(a) then inner(a, z) + BB(b) then inner(b, y) + t + y +let p = AA(AA(3)) +outer1(p, 1, 2) + outer2(p, 3, 4) + inner(AA(5), 6) +//│ = 103 +//│ p = AA(AA(3)) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 103 +//│ 2 fusion opportunities: +//│ AA --match--> `if x is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun mapHead(f, x) = if x is + AAA(h, t) then f(h) +mapHead(x => x, AAA(1, AAA(2, None))) +//│ = 1 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 1 +//│ 1 fusion opportunities: +//│ AAA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun test() = + let k + if A is + A then + k = 3 + k + 2 +test() +//│ = 5 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 5 +//│ 1 fusion opportunities: +//│ A --match--> `if scrut is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + +fun test(a) = if a is + AA then 1 + else 2 +test(B) + test(C) +//│ = 4 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 4 +//│ 2 fusion opportunities: +//│ B --match--> `if a is ...` +//│ C --match--> `if a is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +// technically ill-typed, so `AA(3)` (which flows to `z`) is not fused +fun test(x, y, z) = if x is + AA(a) then + let m = if y is + AA(a1) then a1 + z + let n = if z is + AA(a2) then a2 - z + m + n +test(AA(1), AA(2), AA(3)) +//│ = "2AA(3)NaN" +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = "2AA(3)NaN" +//│ 2 fusion opportunities: +//│ AA --match--> `if x is ...` +//│ AA --match--> `if y is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun test(x) = + let t = id of + if x is + AA(a) then a + BB(a) then a + t(123).c + 5 * 4 - 3 + 2 - 1 +let p = if true then AA(CCC) else BB(CCC) +test(p) +//│ = 141 +//│ p = AA([function CCC]) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 141 +//│ 2 fusion opportunities: +//│ AA --match--> `if x is ...` +//│ BB --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun test(x, y) = + let t = + if y is CCC then + if x is + AA(a) then a + BB(a) then a + t(123).c + 5 * 4 - 3 + 2 - 1 +let p = if true then AA(CCC) else BB(CCC) +test(p, id(CCC(123))) +//│ = 141 +//│ p = AA([function CCC]) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 141 +//│ 2 fusion opportunities: +//│ AA --match--> `if x is ...` +//│ BB --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +:sjs +fun f(x) = if x is + A then 1 +f(id(A)) + f(A) +//│ JS (unsanitized): +//│ let f10, tmp111, tmp112, tmp113; +//│ f10 = function f(x1) { +//│ if (x1 instanceof A1.class) { +//│ return 1 +//│ } else { +//│ throw new globalThis.Error("match error"); +//│ } +//│ }; +//│ tmp111 = Predef.id(A1); +//│ tmp112 = f10(tmp111); +//│ tmp113 = f10(A1); +//│ tmp112 + tmp113 +//│ No fusion opportunity +//│ = 2 +//│ No fusion opportunity + + + +fun f(x) = if x is + A then 0 + B then 1 +fun g(x) = if x is + AA then A + BB then B +f(g(AA(1))) +//│ = 0 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 0 +//│ 3 fusion opportunities: +//│ A --match--> `if x is ...` +//│ AA --match--> `if x is ...` +//│ B --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun f(a, b) = if a is + A and b is B then 1 + else 0 +f(A, B) +//│ = 1 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 1 +//│ 2 fusion opportunities: +//│ A --match--> `if a is ...` +//│ B --match--> `if b is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +// cannot fuse intermediate values created by +// calling data constructors passed around like functions, +// so `c1` is not fused. +fun app(f) = f(1) +fun identity(x) = x +fun c1(x) = if x is + AA(i) then i +fun c2(x) = if x is + AA(i) then i + A then 0 +c1(app(AA)) + c1(AA(2)) + c2(AA(3)) + c2(A) +//│ = 6 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 6 +//│ 2 fusion opportunities: +//│ A --match--> `if x is ...` +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + +:re +fun p(x) = if x + then new AA(2) + else BB(3) +fun c(x) = if x is + AA(a) then a + BB(b) then b + else throw (if A is A then Error("e1") else Error("e2")) +print(c(p(true)) + c(p(false))) +c(B) +//│ > 5 +//│ ═══[RUNTIME ERROR] Error: e1 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ > 5 +//│ ═══[RUNTIME ERROR] Error: e1 +//│ 4 fusion opportunities: +//│ A --match--> `if scrut is ...` +//│ AA --match--> `if x is ...` +//│ B --match--> `if x is ...` +//│ BB --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +fun f(a, b) = if a is + AA(x) then x + b.bb +f(AA(1), BB(2)) +//│ = 3 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 3 +//│ 2 fusion opportunities: +//│ AA --match--> `if a is ...` +//│ BB --sel--> `.Ident(bb)` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< diff --git a/hkmc2/shared/src/test/mlscript/deforest/todos.mls b/hkmc2/shared/src/test/mlscript/deforest/todos.mls new file mode 100644 index 000000000..1dc8dfcb2 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/deforest/todos.mls @@ -0,0 +1,224 @@ +:js +:deforest + +//│ No fusion opportunity +object A +object B +object C +data class AA(x) +data class BB(x) +data class CC(x) + +object Nil +data class Cons(h, t) + +object None +data class Some(x) + + +// since `x.x` is considered as being consumed +// at two places,the inner match is not fused +:sjs +fun f(x) = if x is + AA then if x.x is + AA then x.x.x +f(AA(AA(3))) +//│ JS (unsanitized): +//│ let f, tmp, tmp1; +//│ f = function f(x) { +//│ let scrut; +//│ if (x instanceof AA1.class) { +//│ scrut = x.x; +//│ if (scrut instanceof AA1.class) { +//│ return x.x.x +//│ } else { +//│ throw new globalThis.Error("match error"); +//│ } +//│ } else { +//│ throw new globalThis.Error("match error"); +//│ } +//│ }; +//│ tmp = AA1(3); +//│ tmp1 = AA1(tmp); +//│ f(tmp1) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ ==== JS (deforested): ==== +//│ let f, tmp, tmp1, _deforest_AA_x; +//│ f = function f(x) { +//│ return runtime.safeCall(x()) +//│ }; +//│ tmp = AA1(3); +//│ _deforest_AA_x = tmp; +//│ tmp1 = () => { +//│ let scrut; +//│ scrut = _deforest_AA_x; +//│ if (scrut instanceof AA1.class) { +//│ return _deforest_AA_x.x +//│ } else { +//│ throw new this.Error("match error"); +//│ } +//│ }; +//│ f(tmp1) +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< +//│ = 3 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 3 +//│ 1 fusion opportunities: +//│ AA --match--> `if x is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +// TODO: the computation of `t + 5 * 4 - 3 + 2 - 1` is still duplicated... +:sjs +fun test(x) = + let t = if x is + AA(AA(AA(a))) then a + else 4 + t + 5 * 4 - 3 + 2 - 1 +fun f(a) = if a is + AA(AA) then 0 +let p = AA(AA(AA(10))) +test(p) + f(p) + test(AA(AA(AA(10)))) + test(B) +//│ JS (unsanitized): +//│ let test, f1, p, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12, tmp13, tmp14, tmp15; +//│ test = function test(x) { +//│ let t, param0, param01, param02, a, tmp16, tmp17, tmp18, tmp19, tmp20; +//│ if (x instanceof AA1.class) { +//│ param0 = x.x; +//│ if (param0 instanceof AA1.class) { +//│ param01 = param0.x; +//│ if (param01 instanceof AA1.class) { +//│ param02 = param01.x; +//│ a = param02; +//│ tmp16 = a; +//│ } else { +//│ tmp16 = 4; +//│ } +//│ } else { +//│ tmp16 = 4; +//│ } +//│ } else { +//│ tmp16 = 4; +//│ } +//│ t = tmp16; +//│ tmp17 = 5 * 4; +//│ tmp18 = t + tmp17; +//│ tmp19 = tmp18 - 3; +//│ tmp20 = tmp19 + 2; +//│ return tmp20 - 1 +//│ }; +//│ f1 = function f(a) { +//│ let param0; +//│ if (a instanceof AA1.class) { +//│ param0 = a.x; +//│ if (param0 instanceof AA1.class) { +//│ return 0 +//│ } else { +//│ throw new globalThis.Error("match error"); +//│ } +//│ } else { +//│ throw new globalThis.Error("match error"); +//│ } +//│ }; +//│ tmp4 = AA1(10); +//│ tmp5 = AA1(tmp4); +//│ tmp6 = AA1(tmp5); +//│ p = tmp6; +//│ tmp7 = test(p); +//│ tmp8 = f1(p); +//│ tmp9 = tmp7 + tmp8; +//│ tmp10 = AA1(10); +//│ tmp11 = AA1(tmp10); +//│ tmp12 = AA1(tmp11); +//│ tmp13 = test(tmp12); +//│ tmp14 = tmp9 + tmp13; +//│ tmp15 = test(B1); +//│ tmp14 + tmp15 +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ ==== JS (deforested): ==== +//│ let test, f1, p, tmp4, tmp5, tmp6, tmp7, tmp8, tmp9, tmp10, tmp11, tmp12, tmp13, tmp14, tmp15, _deforest_AA_x_tmp, match_param0_branch_AA, _deforest_AA_x_tmp1; +//│ match_param0_branch_AA = function match_param0_branch_AA(_deforest_AA_x1) { +//│ let t, param0, a, tmp16, tmp17, tmp18, tmp19, tmp20; +//│ param0 = _deforest_AA_x1; +//│ a = param0; +//│ tmp16 = a; +//│ t = tmp16; +//│ tmp17 = 5 * 4; +//│ tmp18 = t + tmp17; +//│ tmp19 = tmp18 - 3; +//│ tmp20 = tmp19 + 2; +//│ return tmp20 - 1 +//│ }; +//│ test = function test(x) { +//│ let t, param0, param01, tmp16, tmp17, tmp18, tmp19, tmp20; +//│ if (x instanceof AA1.class) { +//│ param0 = x.x; +//│ if (param0 instanceof AA1.class) { +//│ param01 = param0.x; +//│ return runtime.safeCall(param01()) +//│ } else { +//│ tmp16 = 4; +//│ } +//│ } else { +//│ tmp16 = 4; +//│ } +//│ t = tmp16; +//│ tmp17 = 5 * 4; +//│ tmp18 = t + tmp17; +//│ tmp19 = tmp18 - 3; +//│ tmp20 = tmp19 + 2; +//│ return tmp20 - 1 +//│ }; +//│ f1 = function f(a) { +//│ let param0; +//│ if (a instanceof AA1.class) { +//│ param0 = a.x; +//│ if (param0 instanceof AA1.class) { +//│ return 0 +//│ } else { +//│ throw new globalThis.Error("match error"); +//│ } +//│ } else { +//│ throw new globalThis.Error("match error"); +//│ } +//│ }; +//│ _deforest_AA_x_tmp = 10; +//│ tmp4 = () => { +//│ return match_param0_branch_AA(_deforest_AA_x_tmp) +//│ }; +//│ tmp5 = AA1(tmp4); +//│ tmp6 = AA1(tmp5); +//│ p = tmp6; +//│ tmp7 = test(p); +//│ tmp8 = f1(p); +//│ tmp9 = tmp7 + tmp8; +//│ _deforest_AA_x_tmp1 = 10; +//│ tmp10 = () => { +//│ return match_param0_branch_AA(_deforest_AA_x_tmp1) +//│ }; +//│ tmp11 = AA1(tmp10); +//│ tmp12 = AA1(tmp11); +//│ tmp13 = test(tmp12); +//│ tmp14 = tmp9 + tmp13; +//│ tmp15 = test(B1); +//│ tmp14 + tmp15 +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< +//│ = 78 +//│ p = AA(AA(AA(10))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = 78 +//│ 2 fusion opportunities: +//│ AA --match--> `if param0 is ...` +//│ AA --match--> `if param0 is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + +// TODO: currently nothing is done at all when seeing +// `Define` blocks defining things other than functions +data class Global(x) +fun test() = + data class Local(x) + if Global(1) is + Global(x) then Local(x + 1) +test() +//│ = Local(2) diff --git a/hkmc2/shared/src/test/mlscript/deforest/zipunzip.mls b/hkmc2/shared/src/test/mlscript/deforest/zipunzip.mls new file mode 100644 index 000000000..73097e546 --- /dev/null +++ b/hkmc2/shared/src/test/mlscript/deforest/zipunzip.mls @@ -0,0 +1,78 @@ +:js +:deforest + +//│ No fusion opportunity + + +object Nil +data class (::) Cons(h, t) +data class Pair(a, b) + +fun zip(xs_zip, ys_zip) = if + xs_zip is x :: xt and ys_zip is y :: yt then Pair(x, y) :: zip(xt, yt) + else Nil +fun unzip(ls_unzip) = if ls_unzip is + Pair(a, b) :: t and unzip(t) is Pair(atail, btail) then Pair(a :: atail, b :: btail) + Nil then Pair(Nil, Nil) +fun enumFromTo(a, b) = if a < b then a :: enumFromTo(a + 1, b) else Nil +fun testUnzipZip(n) = unzip(zip(id(enumFromTo(1, n)), id(enumFromTo(2, n + 3)))) +testUnzipZip(3) +//│ = Pair(Cons(1, Cons(2, Nil)), Cons(2, Cons(3, Nil))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = Pair(Cons(1, Cons(2, Nil)), Cons(2, Cons(3, Nil))) +//│ 4 fusion opportunities: +//│ Cons --match--> `if ls_unzip is ...` +//│ Nil --match--> `if ls_unzip is ...` +//│ Nil --match--> `if ls_unzip is ...` +//│ Pair --match--> `if param0 is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun zip(xs_zip, ys_zip) = if + xs_zip is x :: xt and ys_zip is y :: yt then Pair(x, y) :: zip(xt, yt) + else Nil +fun unzip(ls_unzip) = if ls_unzip is + Pair(a, b) :: t and unzip(t) is Pair(atail, btail) then Pair(a :: atail, b :: btail) + Nil then Pair(Nil, Nil) +fun makeZippedList(n) = if n > 0 then Pair(n, n + 1) :: makeZippedList(n - 1) else Nil +fun testZipUnzip(n) = if unzip(id(makeZippedList(n))) is + Pair(xs, ys) then zip(xs, ys) +testZipUnzip(3) +//│ = Cons(Pair(3, 4), Cons(Pair(2, 3), Cons(Pair(1, 2), Nil))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = Cons(Pair(3, 4), Cons(Pair(2, 3), Cons(Pair(1, 2), Nil))) +//│ 4 fusion opportunities: +//│ Cons --match--> `if xs_zip is ...` +//│ Cons --match--> `if ys_zip is ...` +//│ Nil --match--> `if xs_zip is ...` +//│ Nil --match--> `if ys_zip is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< + + + +fun zip(xs_zip, ys_zip) = if + xs_zip is x :: xt and ys_zip is y :: yt then Pair(x, y) :: zip(xt, yt) + else Nil +fun unzip(ls_unzip) = if ls_unzip is + Pair(a, b) :: t and unzip(t) is Pair(atail, btail) then Pair(a :: atail, b :: btail) + Nil then Pair(Nil, Nil) +fun map(f, ls_map) = if ls_map is + h :: t then f(h) :: map(f, t) + Nil then Nil +fun makeZippedList(n) = if n > 0 then Cons(Pair(n, n + 1), makeZippedList(n - 1)) else Nil +fun testZipMapBothUnzip(n) = if unzip(id(makeZippedList(n))) is + Pair(xs, ys) then zip( + map(x => x + 1, xs), + map(x => x * x, ys) + ) +testZipMapBothUnzip(4) +//│ = Cons(Pair(5, 25), Cons(Pair(4, 16), Cons(Pair(3, 9), Cons(Pair(2, 4), Nil)))) +//│ >>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>> +//│ = Cons(Pair(5, 25), Cons(Pair(4, 16), Cons(Pair(3, 9), Cons(Pair(2, 4), Nil)))) +//│ 4 fusion opportunities: +//│ Cons --match--> `if ls_map is ...` +//│ Cons --match--> `if ls_map is ...` +//│ Nil --match--> `if ls_map is ...` +//│ Nil --match--> `if ls_map is ...` +//│ <<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<< diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala index 93ff05813..e8aec6ee6 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/JSBackendDiffMaker.scala @@ -24,6 +24,8 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: val showJS = NullaryCommand("sjs") val showRepl = NullaryCommand("showRepl") val traceJS = NullaryCommand("traceJS") + val deforestFlag = NullaryCommand("deforest") + val deforestInfo = NullaryCommand("deforestInfo") val expect = Command("expect"): ln => ln.trim @@ -41,6 +43,10 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: override def doTrace = showRepl.isSet override def emitDbg(str: String): Unit = output(str) + val deforestTL = new TraceLogger: + override def doTrace: Bool = deforestInfo.isSet + override def emitDbg(str: String): Unit = output(str) + lazy val host = hostCreated = true given TL = replTL @@ -63,6 +69,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: val outerRaise: Raise = summon val reportedMessages = mutable.Set.empty[Str] + var correctResult: Opt[Str] = None if showJS.isSet then given Raise = @@ -82,6 +89,27 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: val jsStr = je.stripBreaks.mkString(100) output(s"JS (unsanitized):") output(jsStr) + + if deforestFlag.isSet then + val deforest = new Deforest(using deforestTL) + val deforestRes -> _ -> num = deforest(le) + deforestRes match + case None => () + case Some(_) if num == 0 => output("No fusion opportunity") + case Some(deforestRes) => + output(">>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>>") + if showLoweredTree.isSet then + output("\n==== deforested tree ====") + output(deforestRes.showAsTree) + output("\n") + + val je = baseScp.nest.givenIn: + jsb.program(deforestRes, N, wd) + output("==== JS (deforested): ====") + val jsStr = je.stripBreaks.mkString(100) + output(jsStr) + output("<<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<<") + if js.isSet then given Elaborator.Ctx = curCtx given Raise = @@ -96,43 +124,40 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: val jsb = ltl.givenIn: new JSBuilder with JSBuilderArgNumSanityChecks - val resSym = new TempSymbol(S(blk), "block$res") - val lowered0 = low.program(blk) - val le = lowered0.copy(main = lowered0.main.mapTail: - case e: End => - Assign(resSym, Value.Lit(syntax.Tree.UnitLit(false)), e) - case Return(res, implct) => - assert(implct) - Assign(resSym, res, Return(Value.Lit(syntax.Tree.UnitLit(false)), true)) - case tl: (Throw | Break | Continue) => tl - ) - if showLoweredTree.isSet then - output(s"Lowered:") - output(le.showAsTree) - // * We used to do this to avoid needlessly generating new variable names in separate blocks: - // val nestedScp = baseScp.nest - val nestedScp = baseScp - // val nestedScp = codegen.js.Scope(S(baseScp), curCtx.outer, collection.mutable.Map.empty) // * not needed + def getResSymAndResNme(n: Str) = + val resSym = new TempSymbol(S(blk), n) + resSym -> baseScp.allocateName(resSym) - val resNme = nestedScp.allocateName(resSym) + def assignResultSymForBlock(lowered: Program, resSym: TempSymbol) = + lowered.copy(main = lowered.main.mapTail: + case e: End => + Assign(resSym, Value.Lit(syntax.Tree.UnitLit(false)), e) + case Return(res, implct) => + assert(implct) + Assign(resSym, res, Return(Value.Lit(syntax.Tree.UnitLit(false)), true)) + case tl: (Throw | Break | Continue) => tl + ) - if ppLoweredTree.isSet then - output(s"Pretty Lowered:") - output(Printer.mkDocument(le)(using summon[Raise], nestedScp).toString) + def mkJS(le: Program) = + val (pre, js) = baseScp.givenIn: + jsb.worksheet(le) + val preStr = pre.stripBreaks.mkString(100) + val jsStr = js.stripBreaks.mkString(100) + if showSanitizedJS.isSet then + output(s"JS:") + output(jsStr) + preStr -> jsStr - val (pre, js) = nestedScp.givenIn: - jsb.worksheet(le) - val preStr = pre.stripBreaks.mkString(100) - val jsStr = js.stripBreaks.mkString(100) - if showSanitizedJS.isSet then - output(s"JS:") - output(jsStr) - def mkQuery(preStr: Str, jsStr: Str)(k: Str => Unit) = + def mkQuery(preStr: Str, jsStr: Str)(handleResult: Iterable[Str] => Unit) = val queryStr = jsStr.replaceAll("\n", " ") val (reply, stderr) = host.query(preStr, queryStr, !expectRuntimeOrCodeGenErrors && fixme.isUnset && todo.isUnset) reply match - case ReplHost.Result(content) => k(content) + case ReplHost.Result(content) => + val res :+ end = content.splitSane('\n') : @unchecked + // TODO: seems that not all programs end with "undefined" now + // assert(end == "undefined") + handleResult(res) case ReplHost.Empty => case ReplHost.Unexecuted(message) => ??? case ReplHost.Error(isSyntaxError, message, otherOutputs) => @@ -150,20 +175,60 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: source = Diagnostic.Source.Runtime)) if stderr.nonEmpty then output(s"// Standard Error:\n${stderr}") - if traceJS.isSet then - host.execute( - s"$runtimeNme.TraceLogger.enabled = true; " + - s"$runtimeNme.TraceLogger.resetIndent(0)") + def executeJS(preStr: Str, jsStr: Str, resNme: Str) = + if traceJS.isSet then + host.execute( + s"$runtimeNme.TraceLogger.enabled = true; " + + s"$runtimeNme.TraceLogger.resetIndent(0)") + + // * Sometimes the JS block won't execute due to a syntax or runtime error so we always set this first + host.execute(s"$resNme = undefined") + + mkQuery(preStr, jsStr): stdout => + stdout.foreach: line => + output(s"> ${line}") + if traceJS.isSet then + host.execute(s"$runtimeNme.TraceLogger.enabled = false") - // * Sometimes the JS block won't execute due to a syntax or runtime error so we always set this first - host.execute(s"$resNme = undefined") + def handleDefinedValues(nme: Str, sym: Symbol, expect: Opt[Str])(handleResult: Str => Unit) = + val le = + import codegen.* + Return( + Call( + Value.Ref(Elaborator.State.runtimeSymbol).selSN("printRaw"), + Arg(false, Value.Ref(sym)) :: Nil)(true, false), + implct = true) + val je = baseScp.givenIn: + jsb.block(le, endSemi = false) + val jsStr = je.stripBreaks.mkString(100) + mkQuery("", jsStr): out => + val result = out.mkString + expect match + case S(expected) if result =/= expected => raise: + ErrorReport(msg"Expected: '${expected}', got: '${result}'" -> N :: Nil, + source = Diagnostic.Source.Runtime) + case _ => () + val anon = nme.isEmpty + handleResult(result) + result match + case "undefined" if anon => + case "()" if anon => + case _ => + output(s"${if anon then "" else s"$nme "}= ${result.indentNewLines("| ")}") - mkQuery(preStr, jsStr): stdout => - stdout.splitSane('\n').init // should always ends with "undefined" (TODO: check) - .foreach: line => - output(s"> ${line}") - if traceJS.isSet then - host.execute(s"$runtimeNme.TraceLogger.enabled = false") + val lowered0 = low.program(blk) + val resSym -> resNme = getResSymAndResNme("block$res") + val le = assignResultSymForBlock(lowered0, resSym) + if showLoweredTree.isSet then + output(s"Lowered:") + output(le.showAsTree) + + if ppLoweredTree.isSet then + output(s"Pretty Lowered:") + output(Printer.mkDocument(le)(using summon[Raise], baseScp).toString) + + val (preStr, jsStr) = mkJS(le) + executeJS(preStr, jsStr, resNme) if silent.isUnset then import Elaborator.Ctx.* @@ -177,29 +242,31 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker: case _ => N case _ => N val valuesToPrint = ("", resSym, expect.get) +: definedValues.toSeq.sortBy(_._1) - valuesToPrint.foreach: (nme, sym, expect) => - val le = - import codegen.* - Return( - Call( - Value.Ref(Elaborator.State.runtimeSymbol).selSN("printRaw"), - Arg(false, Value.Ref(sym)) :: Nil)(true, false), - implct = true) - val je = nestedScp.givenIn: - jsb.block(le, endSemi = false) - val jsStr = je.stripBreaks.mkString(100) - mkQuery("", jsStr): out => - val result = out.splitSane('\n').init.mkString // should always ends with "undefined" (TODO: check) - expect match - case S(expected) if result =/= expected => raise: - ErrorReport(msg"Expected: '${expected}', got: '${result}'" -> N :: Nil, - source = Diagnostic.Source.Runtime) - case _ => () - val anon = nme.isEmpty - result match - case "undefined" if anon => - case "()" if anon => - case _ => - output(s"${if anon then "" else s"$nme "}= ${result.indentNewLines("| ")}") + valuesToPrint.foreach: (nme, sym, expected) => + handleDefinedValues(nme, sym, expected)(if sym === resSym then r => correctResult = S(r) else _ => ()) - + if deforestFlag.isSet then + val deforestLow = ltl.givenIn: + codegen.Lowering() + val lowered0 = deforestLow.program(blk) + val deforest = new Deforest(using deforestTL) + val maybeDeforestRes -> deforestStat -> num = deforest(lowered0) + maybeDeforestRes match + case None => () + case Some(_) if num == 0 => output("No fusion opportunity") + case Some(deforestRes) => + output(">>>>>>>>>>>>>>>>>>>>>>>>>>> Deforestation >>>>>>>>>>>>>>>>>>>>>>>>>>>") + val resSym -> resNme = getResSymAndResNme("block$res_deforest") + val le = assignResultSymForBlock(deforestRes, resSym) + val (preStr, jsStr) = mkJS(le) + executeJS(preStr, jsStr, resNme) + + if silent.isUnset then + handleDefinedValues("", resSym, expect.get): result => + if correctResult.fold(false)(_ != result) then raise: + ErrorReport( + msg"The result from deforestated program (\"${result}\") is different from the one computed by the original prorgam (\"${correctResult.get}\")" -> N :: Nil, + source = Diagnostic.Source.Runtime) + + output(deforestStat) + output("<<<<<<<<<<<<<<<<<<<<<<<<<<< Deforestation <<<<<<<<<<<<<<<<<<<<<<<<<<<")