Skip to content

Commit 4a5a038

Browse files
committed
Clean up DiffTests and fix its use in compiler subproject
1 parent dd9a0cd commit 4a5a038

File tree

4 files changed

+51
-85
lines changed

4 files changed

+51
-85
lines changed

build.sbt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ lazy val compiler = crossProject(JSPlatform, JVMPlatform).in(file("compiler"))
8383
sourceDirectory := baseDirectory.value.getParentFile()/"shared",
8484
watchSources += WatchSource(
8585
baseDirectory.value.getParentFile()/"shared"/"test"/"diff", "*.mls", NothingFilter),
86+
watchSources += WatchSource(
87+
baseDirectory.value.getParentFile()/"shared"/"test"/"diff-ir", "*.mls", NothingFilter),
8688
)
8789
.dependsOn(mlscript % "compile->compile;test->test")
8890

compiler/shared/test/scala/mlscript/compiler/Test.scala

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,10 @@ import scala.collection.mutable.StringBuilder
77
import mlscript.compiler.TreeDebug
88
import simpledef.SimpleDef
99

10-
class DiffTestCompiler extends DiffTests {
11-
import DiffTestCompiler.*
10+
import DiffTestCompiler.*
11+
12+
class DiffTestCompiler extends DiffTests(State) {
13+
1214
override def postProcess(mode: ModeType, basePath: List[Str], testName: Str, unit: TypingUnit, output: Str => Unit, raise: Diagnostic => Unit): (List[Str], Option[TypingUnit]) =
1315
val outputBuilder = StringBuilder()
1416

@@ -47,21 +49,11 @@ class DiffTestCompiler extends DiffTests {
4749
}
4850
None
4951

50-
override protected lazy val files = allFiles.filter { file =>
51-
val fileName = file.baseName
52-
validExt(file.ext) && filter(file.relativeTo(pwd))
53-
}
5452
}
5553

5654
object DiffTestCompiler {
57-
58-
private val pwd = os.pwd
59-
private val dir = pwd/"compiler"/"shared"/"test"/"diff"
6055

61-
private val allFiles = os.walk(dir).filter(_.toIO.isFile)
62-
63-
private val validExt = Set("fun", "mls")
64-
65-
private def filter(file: os.RelPath) = DiffTests.filter(file)
56+
lazy val State =
57+
new DiffTests.State(DiffTests.pwd/"compiler"/"shared"/"test"/"diff")
6658

6759
}

compiler/shared/test/scala/mlscript/compiler/TestIR.scala

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ import mlscript.compiler.ir._
66
import scala.collection.mutable.StringBuilder
77
import mlscript.compiler.optimizer.TailRecOpt
88

9-
class IRDiffTestCompiler extends DiffTests {
10-
import IRDiffTestCompiler.*
9+
import IRDiffTestCompiler.*
10+
11+
class IRDiffTestCompiler extends DiffTests(State) {
1112

1213
override def postProcess(mode: ModeType, basePath: List[Str], testName: Str, unit: TypingUnit, output: Str => Unit, raise: Diagnostic => Unit): (List[Str], Option[TypingUnit]) =
1314
val outputBuilder = StringBuilder()
@@ -58,21 +59,11 @@ class IRDiffTestCompiler extends DiffTests {
5859

5960
(outputBuilder.toString().linesIterator.toList, None)
6061

61-
override protected lazy val files = allFiles.filter { file =>
62-
val fileName = file.baseName
63-
validExt(file.ext) && filter(file.relativeTo(pwd))
64-
}
6562
}
6663

6764
object IRDiffTestCompiler {
68-
69-
private val pwd = os.pwd
70-
private val dir = pwd/"compiler"/"shared"/"test"/"diff-ir"
7165

72-
private val allFiles = os.walk(dir).filter(_.toIO.isFile)
73-
74-
private val validExt = Set("fun", "mls")
75-
76-
private def filter(file: os.RelPath) = DiffTests.filter(file)
66+
lazy val State =
67+
new DiffTests.State(DiffTests.pwd/"compiler"/"shared"/"test"/"diff-ir")
7768

7869
}

shared/src/test/scala/mlscript/DiffTests.scala

Lines changed: 38 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import org.scalatest.{funsuite, ParallelTestExecution}
1212
import org.scalatest.time._
1313
import org.scalatest.concurrent.{TimeLimitedTests, Signaler}
1414
import pretyper.PreTyper
15+
import os.Path
1516

1617
abstract class ModeType {
1718
def expectTypeErrors: Bool
@@ -50,12 +51,15 @@ abstract class ModeType {
5051
def nolift: Bool
5152
}
5253

53-
class DiffTests
54+
class DiffTests(state: DiffTests.State)
5455
extends funsuite.AnyFunSuite
5556
with ParallelTestExecution
5657
with TimeLimitedTests
5758
{
5859

60+
def this() = this(DiffTests.State)
61+
62+
import state._
5963

6064
/** Hook for dependent projects, like the monomorphizer. */
6165
def postProcess(mode: ModeType, basePath: Ls[Str], testName: Str, unit: TypingUnit, output: Str => Unit, raise: Diagnostic => Unit): (Ls[Str], Option[TypingUnit]) = (Nil, None)
@@ -65,14 +69,12 @@ class DiffTests
6569
@SuppressWarnings(Array("org.wartremover.warts.RedundantIsInstanceOf"))
6670
private val inParallel = isInstanceOf[ParallelTestExecution]
6771

68-
import DiffTests._
69-
7072
// scala test will not execute a test if the test class has constructor parameters.
7173
// override this to get the correct paths of test files.
7274
protected lazy val files = allFiles.filter { file =>
7375
val fileName = file.baseName
7476
// validExt(file.ext) && filter(fileName)
75-
validExt(file.ext) && filter(file.relativeTo(pwd))
77+
validExt(file.ext) && filter(file.relativeTo(DiffTests.pwd))
7678
}
7779

7880
val timeLimit = TimeLimit
@@ -240,7 +242,7 @@ class DiffTests
240242
case "p" => mode.copy(showParse = true)
241243
case "d" => mode.copy(dbg = true)
242244
case "dp" => mode.copy(dbgParsing = true)
243-
case DebugUCSFlags(x) => mode.copy(dbgUCS = mode.dbgUCS.fold(S(x))(y => S(y ++ x)))
245+
case DiffTests.DebugUCSFlags(x) => mode.copy(dbgUCS = mode.dbgUCS.fold(S(x))(y => S(y ++ x)))
244246
case "ds" => mode.copy(dbgSimplif = true)
245247
case "dl" => mode.copy(dbgLifting = true)
246248
case "dd" => mode.copy(dbgDefunc = true)
@@ -1139,61 +1141,40 @@ class DiffTests
11391141

11401142
object DiffTests {
11411143

1142-
private val TimeLimit =
1143-
if (sys.env.get("CI").isDefined) Span(60, Seconds)
1144-
else Span(30, Seconds)
1145-
1146-
private val pwd = os.pwd
1147-
private val dir = pwd/"shared"/"src"/"test"/"diff"
1148-
1149-
private val allFiles = os.walk(dir).filter(_.toIO.isFile)
1144+
val pwd: Path = os.pwd
11501145

1151-
private val validExt = Set("fun", "mls")
1146+
lazy val State = new State(pwd/"shared"/"src"/"test"/"diff")
11521147

1153-
// Aggregate unstaged modified files to only run the tests on them, if there are any
1154-
private val modified: Set[os.RelPath] =
1155-
try os.proc("git", "status", "--porcelain", dir).call().out.lines().iterator.flatMap { gitStr =>
1156-
println(" [git] " + gitStr)
1157-
val prefix = gitStr.take(2)
1158-
val filePath = os.RelPath(gitStr.drop(3))
1159-
if (prefix =:= "A " || prefix =:= "M " || prefix =:= "R " || prefix =:= "D ")
1160-
N // * Disregard modified files that are staged
1161-
else S(filePath)
1162-
}.toSet catch {
1163-
case err: Throwable => System.err.println("/!\\ git command failed with: " + err)
1164-
Set.empty
1165-
}
1166-
1167-
// Allow overriding which specific tests to run, sometimes easier for development:
1168-
private val focused = Set[Str](
1169-
// "LetRec"
1170-
// "Ascribe",
1171-
// "Repro",
1172-
// "RecursiveTypes",
1173-
// "Simple",
1174-
// "Inherit",
1175-
// "Basics",
1176-
// "Paper",
1177-
// "Negations",
1178-
// "RecFuns",
1179-
// "With",
1180-
// "Annoying",
1181-
// "Tony",
1182-
// "Lists",
1183-
// "Traits",
1184-
// "BadTraits",
1185-
// "TraitMatching",
1186-
// "Subsume",
1187-
// "Methods",
1188-
).map(os.RelPath(_))
1189-
// private def filter(name: Str): Bool =
1190-
def filter(file: os.RelPath): Bool = {
1191-
if (focused.nonEmpty) focused(file) else modified(file) || modified.isEmpty &&
1192-
true
1193-
// name.startsWith("new/")
1194-
// file.segments.toList.init.lastOption.contains("parser")
1148+
class State(val dir: Path) {
1149+
1150+
val TimeLimit: Span =
1151+
if (sys.env.get("CI").isDefined) Span(60, Seconds)
1152+
else Span(30, Seconds)
1153+
1154+
val allFiles: IndexedSeq[Path] = os.walk(dir).filter(_.toIO.isFile)
1155+
1156+
val validExt: Set[String] = Set("fun", "mls")
1157+
1158+
// Aggregate unstaged modified files to only run the tests on them, if there are any
1159+
val modified: Set[os.RelPath] =
1160+
try os.proc("git", "status", "--porcelain", dir).call().out.lines().iterator.flatMap { gitStr =>
1161+
println(" [git] " + gitStr)
1162+
val prefix = gitStr.take(2)
1163+
val filePath = os.RelPath(gitStr.drop(3))
1164+
if (prefix =:= "A " || prefix =:= "M " || prefix =:= "R " || prefix =:= "D ")
1165+
N // * Disregard modified files that are staged
1166+
else S(filePath)
1167+
}.toSet catch {
1168+
case err: Throwable => System.err.println("/!\\ git command failed with: " + err)
1169+
Set.empty
1170+
}
1171+
1172+
// private def filter(name: Str): Bool =
1173+
def filter(file: os.RelPath): Bool =
1174+
modified(file) || modified.isEmpty
1175+
11951176
}
1196-
1177+
11971178
object DebugUCSFlags {
11981179
// E.g. "ducs", "ducs:foo", "ducs:foo,bar", "ducs:a.b.c,foo"
11991180
private val pattern = "^ducs(?::(\\s*(?:[A-Za-z\\.-]+)(?:,\\s*[A-Za-z\\.-]+)*))?$".r

0 commit comments

Comments
 (0)