
Commit 4a2f47f

[query] Move LoweredTableReaderCoercer into ExecuteContext (hail-is#14696)
Refactored table reader coercion and caching mechanism.

### What changed?

- Removed `shouldCacheQueryInfo` method from `Backend` class
- Introduced `CoercerCache` in `ExecuteContext`
- Refactored `LoweredTableReader.makeCoercer` to return a function instead of a class
- Removed local caching in `GenericTableValue` and `LoweredTableReader`

### Why make this change?

This change aims to optimize table reader coercion by:

- Centralizing caching logic in `ExecuteContext`
- Allowing more flexible caching strategies across different backend implementations
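Illustration only: a self-contained sketch of the caching pattern the diffs below adopt — the coercer is looked up with `getOrElseUpdate` on a cache owned by the execution context, so it is built at most once per cache key and shared across reads. `Ctx`, `Coercer`, and the simplified key are stand-ins for Hail's `ExecuteContext`, `LoweredTableReaderCoercer`, and the real key tuple, not the actual signatures.

```scala
import scala.collection.mutable

object CoercerCacheSketch {
  // Stand-in for the real LoweredTableReaderCoercer; the actual type also takes
  // an ExecuteContext, a globals IR, a context type, and a body function.
  type Coercer = IndexedSeq[Any] => String

  // Stand-in for ExecuteContext: owns the per-session coercer cache.
  final class Ctx {
    val PersistedCoercerCache: mutable.Map[Any, Coercer] = mutable.Map.empty
  }

  var coercersBuilt = 0

  // Stand-in for LoweredTableReader.makeCoercer: expensive to construct,
  // so it should run at most once per distinct cache key.
  def makeCoercer(keyFields: IndexedSeq[String]): Coercer = {
    coercersBuilt += 1
    contexts => s"coerced ${contexts.length} partitions by key ${keyFields.mkString(",")}"
  }

  def getCoercer(ctx: Ctx, keyFields: IndexedSeq[String], cacheKey: Any): Coercer =
    // Centralized caching: the cache lives on the context, not on the reader,
    // so repeated reads of the same source reuse the coercer.
    ctx.PersistedCoercerCache.getOrElseUpdate(
      (keyFields, cacheKey),
      makeCoercer(keyFields),
    )

  def main(args: Array[String]): Unit = {
    val ctx = new Ctx
    val key = IndexedSeq("locus", "alleles")
    getCoercer(ctx, key, cacheKey = "table-A")(IndexedSeq(1, 2, 3))
    getCoercer(ctx, key, cacheKey = "table-A")(IndexedSeq(4, 5))
    assert(coercersBuilt == 1) // the second lookup hits the cache
  }
}
```

The second lookup with the same key reuses the cached coercer, which is the behavior that previously relied on a per-reader `private var` and the `shouldCacheQueryInfo` flag.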
1 parent 9bae13e commit 4a2f47f

7 files changed: +360, -377 lines

hail/hail/src/is/hail/backend/Backend.scala

Lines changed: 0 additions & 2 deletions
@@ -85,8 +85,6 @@ abstract class Backend extends Closeable {
   def asSpark(op: String): SparkBackend =
     fatal(s"${getClass.getSimpleName}: $op requires SparkBackend")
 
-  def shouldCacheQueryInfo: Boolean = true
-
   def lowerDistributedSort(
     ctx: ExecuteContext,
     stage: TableStage,

hail/hail/src/is/hail/backend/ExecuteContext.scala

Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,7 @@ import is.hail.annotations.{Region, RegionPool}
 import is.hail.asm4s.HailClassLoader
 import is.hail.backend.local.LocalTaskContext
 import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction}
+import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
 import is.hail.expr.ir.lowering.IrMetadata
 import is.hail.io.fs.FS
 import is.hail.linalg.BlockMatrix
@@ -76,6 +77,7 @@ object ExecuteContext {
     blockMatrixCache: mutable.Map[String, BlockMatrix],
     codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
     irCache: mutable.Map[Int, BaseIR],
+    coercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
   )(
     f: ExecuteContext => T
   ): T = {
@@ -97,6 +99,7 @@
       blockMatrixCache,
       codeCache,
       irCache,
+      coercerCache,
     ))(f(_))
   }
 }
@@ -129,6 +132,7 @@ class ExecuteContext(
   val BlockMatrixCache: mutable.Map[String, BlockMatrix],
   val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
   val PersistedIrCache: mutable.Map[Int, BaseIR],
+  val PersistedCoercerCache: mutable.Map[Any, LoweredTableReaderCoercer],
 ) extends Closeable {
 
   val rngNonce: Long =
@@ -198,6 +202,7 @@ class ExecuteContext(
     blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
     codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
     persistedIrCache: mutable.Map[Int, BaseIR] = this.PersistedIrCache,
+    persistedCoercerCache: mutable.Map[Any, LoweredTableReaderCoercer] = this.PersistedCoercerCache,
   )(
     f: ExecuteContext => A
   ): A =
@@ -217,5 +222,6 @@
       blockMatrixCache,
       codeCache,
       persistedIrCache,
+      persistedCoercerCache,
     ))(f)
 }

hail/hail/src/is/hail/backend/local/LocalBackend.scala

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,7 @@ import is.hail.backend._
 import is.hail.backend.py4j.Py4JBackendExtensions
 import is.hail.expr.Validate
 import is.hail.expr.ir._
+import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
 import is.hail.expr.ir.analyses.SemanticHash
 import is.hail.expr.ir.compile.Compile
 import is.hail.expr.ir.defs.MakeTuple
@@ -94,6 +95,7 @@ class LocalBackend(
   private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
   private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
   private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()
+  private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)
 
   // flags can be set after construction from python
   def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
@@ -119,6 +121,7 @@
       ImmutableMap.empty,
       codeCache,
       persistedIR,
+      coercerCache,
     )(f)
   }
 

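Both `LocalBackend` here and `SparkBackend` below bound the new per-backend cache at 32 entries via `new Cache[Any, LoweredTableReaderCoercer](32)`. As a rough illustration only, a self-contained sketch of what such a fixed-capacity cache could look like, assuming an access-ordered `LinkedHashMap` with eldest-entry eviction; the actual `Cache` class in the Hail codebase may be implemented differently.

```scala
import scala.collection.mutable
import scala.jdk.CollectionConverters._

// Sketch of a bounded LRU map exposed through the mutable.Map interface.
// Assumption: eviction of the least-recently-accessed entry once capacity
// is exceeded, which is what a small coercer/code cache typically wants.
final class BoundedCache[K, V](capacity: Int) extends mutable.AbstractMap[K, V] {
  private[this] val m =
    new java.util.LinkedHashMap[K, V](capacity, 0.75f, /* accessOrder = */ true) {
      override def removeEldestEntry(eldest: java.util.Map.Entry[K, V]): Boolean =
        size() > capacity
    }

  override def get(key: K): Option[V] = Option(m.get(key))

  override def iterator: Iterator[(K, V)] =
    m.entrySet().iterator().asScala.map(e => e.getKey -> e.getValue)

  override def addOne(kv: (K, V)): this.type = { m.put(kv._1, kv._2); this }

  override def subtractOne(key: K): this.type = { m.remove(key); this }
}
```

Because it is a `mutable.Map`, a cache like this can be handed straight to `ExecuteContext` and used through `getOrElseUpdate`, exactly like the unbounded maps used for the other caches.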
hail/hail/src/is/hail/backend/service/ServiceBackend.scala

Lines changed: 2 additions & 3 deletions
@@ -136,8 +136,6 @@ class ServiceBackend(
   private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000
   private[this] val executor = Executors.newFixedThreadPool(MAX_AVAILABLE_GCS_CONNECTIONS)
 
-  override def shouldCacheQueryInfo: Boolean = false
-
   def defaultParallelism: Int = 4
 
   def broadcast[T: ClassTag](_value: T): BroadcastValue[T] = {
@@ -432,7 +430,8 @@
       serviceBackendContext,
       new IrMetadata(),
       ImmutableMap.empty,
-      mutable.Map.empty,
+      ImmutableMap.empty,
+      ImmutableMap.empty,
       ImmutableMap.empty,
     )(f)
   }

hail/hail/src/is/hail/backend/spark/SparkBackend.scala

Lines changed: 4 additions & 0 deletions
@@ -7,6 +7,7 @@ import is.hail.backend._
 import is.hail.backend.py4j.Py4JBackendExtensions
 import is.hail.expr.Validate
 import is.hail.expr.ir._
+import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
 import is.hail.expr.ir.analyses.SemanticHash
 import is.hail.expr.ir.compile.Compile
 import is.hail.expr.ir.defs.MakeTuple
@@ -356,6 +357,7 @@ class SparkBackend(
   private[this] val bmCache = mutable.Map.empty[String, BlockMatrix]
   private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
   private[this] val persistedIr = mutable.Map.empty[Int, BaseIR]
+  private[this] val coercerCache = new Cache[Any, LoweredTableReaderCoercer](32)
 
   def createExecuteContextForTests(
     timer: ExecutionTimer,
@@ -381,6 +383,7 @@
       ImmutableMap.empty,
       ImmutableMap.empty,
       ImmutableMap.empty,
+      ImmutableMap.empty,
     )
 
   override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
@@ -403,6 +406,7 @@
       bmCache,
       codeCache,
       persistedIr,
+      coercerCache,
     )(f)
   }
 

hail/hail/src/is/hail/expr/ir/GenericTableValue.scala

Lines changed: 12 additions & 23 deletions
@@ -3,6 +3,7 @@ package is.hail.expr.ir
 import is.hail.annotations.Region
 import is.hail.asm4s._
 import is.hail.backend.ExecuteContext
+import is.hail.expr.ir.LoweredTableReader.LoweredTableReaderCoercer
 import is.hail.expr.ir.defs.{Literal, PartitionReader, ReadPartition, ToStream}
 import is.hail.expr.ir.functions.UtilFunctions
 import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
@@ -144,16 +145,6 @@ class PartitionIteratorLongReader(
   )
 }
 
-abstract class LoweredTableReaderCoercer {
-  def coerce(
-    ctx: ExecuteContext,
-    globals: IR,
-    contextType: Type,
-    contexts: IndexedSeq[Any],
-    body: IR => IR,
-  ): TableStage
-}
-
 class GenericTableValue(
   val fullTableType: TableType,
   val uidFieldName: String,
@@ -169,12 +160,11 @@ class GenericTableValue(
   assert(contextType.hasField("partitionIndex"))
   assert(contextType.fieldType("partitionIndex") == TInt32)
 
-  private var ltrCoercer: LoweredTableReaderCoercer = _
-
   private def getLTVCoercer(ctx: ExecuteContext, context: String, cacheKey: Any)
-    : LoweredTableReaderCoercer = {
-    if (ltrCoercer == null) {
-      ltrCoercer = LoweredTableReader.makeCoercer(
+    : LoweredTableReaderCoercer =
+    ctx.PersistedCoercerCache.getOrElseUpdate(
+      (1, contextType, fullTableType.key, cacheKey),
+      LoweredTableReader.makeCoercer(
        ctx,
        fullTableType.key,
        1,
@@ -185,11 +175,8 @@
        bodyPType,
        body,
        context,
-        cacheKey,
-      )
-    }
-    ltrCoercer
-  }
+      ),
+    )
 
   def toTableStage(ctx: ExecuteContext, requestedType: TableType, context: String, cacheKey: Any)
     : TableStage = {
@@ -218,11 +205,13 @@
       val contextsIR = ToStream(Literal(TArray(contextType), contexts))
       TableStage(globalsIR, p, TableStageDependency.none, contextsIR, requestedBody)
     } else {
-      getLTVCoercer(ctx, context, cacheKey).coerce(
+      getLTVCoercer(ctx, context, cacheKey)(
        ctx,
        globalsIR,
-        contextType, contexts,
-        requestedBody)
+        contextType,
+        contexts,
+        requestedBody,
+      )
     }
   }
 }

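The hunks above also show the class-to-function change called out in the commit message: the local `abstract class LoweredTableReaderCoercer` and its `coerce` method are deleted, and the value returned by `getLTVCoercer` is now applied directly. Below is a minimal stand-alone sketch of that shape change, using toy types in place of Hail's `IR` and `TableStage`; the real `LoweredTableReaderCoercer` alias lives in `LoweredTableReader` and, per the call site above, also takes an `ExecuteContext` and a context type.

```scala
object CoercerShapeSketch {
  // Toy stand-ins for Hail's IR and TableStage.
  final case class IR(desc: String)
  final case class TableStage(desc: String)

  // Before this commit: a coercer was an instance of an abstract class with a
  // `coerce` method, cached in a `private var` on each GenericTableValue.
  abstract class CoercerClass {
    def coerce(globals: IR, contexts: IndexedSeq[Any], body: IR => IR): TableStage
  }

  // After: a coercer is a plain function value, so it can be stored in a
  // Map[Any, Coercer] on the ExecuteContext and call sites apply it directly.
  type CoercerFn = (IR, IndexedSeq[Any], IR => IR) => TableStage

  // Stand-in for LoweredTableReader.makeCoercer returning the function form.
  def makeCoercer(key: String): CoercerFn =
    (globals, contexts, body) =>
      TableStage(s"keyed by $key: ${body(globals).desc} over ${contexts.length} contexts")

  def main(args: Array[String]): Unit = {
    val coercer = makeCoercer("locus")
    // Mirrors the call-site change in the diff: coercer(...) rather than coercer.coerce(...).
    val stage = coercer(IR("globals"), IndexedSeq(0, 1), ir => IR(ir.desc + "+body"))
    println(stage.desc)
  }
}
```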