Skip to content

Commit

Permalink
[query] Remove lookupOrCompileCachedFunction from Backend interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Dec 17, 2024
1 parent 0d9d8f3 commit a99150b
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 133 deletions.
26 changes: 1 addition & 25 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@ import is.hail.asm4s._
import is.hail.backend.Backend.jsonToBytes
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{
BaseIR, CodeCacheKey, CompiledFunction, IR, IRParser, IRParserEnvironment, LoweringAnalyses,
SortField, TableIR, TableReader,
BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader,
}
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.io.{BufferSpec, TypedCodecSpec}
Expand Down Expand Up @@ -92,9 +91,6 @@ abstract class Backend extends Closeable {

def shouldCacheQueryInfo: Boolean = true

def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T]

def lowerDistributedSort(
ctx: ExecuteContext,
stage: TableStage,
Expand Down Expand Up @@ -193,23 +189,3 @@ abstract class Backend extends Closeable {

def execute(ctx: ExecuteContext, ir: IR): Either[Unit, (PTuple, Long)]
}

trait BackendWithCodeCache {
private[this] val codeCache: Cache[CodeCacheKey, CompiledFunction[_]] = new Cache(50)

def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T] = {
codeCache.get(k) match {
case Some(v) => v.asInstanceOf[CompiledFunction[T]]
case None =>
val compiledFunction = f
codeCache += ((k, compiledFunction))
compiledFunction
}
}
}

trait BackendWithNoCodeCache {
def lookupOrCompileCachedFunction[T](k: CodeCacheKey)(f: => CompiledFunction[T])
: CompiledFunction[T] = f
}
6 changes: 6 additions & 0 deletions hail/src/main/scala/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags}
import is.hail.annotations.{Region, RegionPool}
import is.hail.asm4s.HailClassLoader
import is.hail.backend.local.LocalTaskContext
import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}
import is.hail.expr.ir.lowering.IrMetadata
import is.hail.io.fs.FS
import is.hail.linalg.BlockMatrix
Expand Down Expand Up @@ -73,6 +74,7 @@ object ExecuteContext {
backendContext: BackendContext,
irMetadata: IrMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix],
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
)(
f: ExecuteContext => T
): T = {
Expand All @@ -92,6 +94,7 @@ object ExecuteContext {
backendContext,
irMetadata,
blockMatrixCache,
codeCache,
))(f(_))
}
}
Expand Down Expand Up @@ -122,6 +125,7 @@ class ExecuteContext(
val backendContext: BackendContext,
val irMetadata: IrMetadata,
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
) extends Closeable {

val rngNonce: Long =
Expand Down Expand Up @@ -191,6 +195,7 @@ class ExecuteContext(
backendContext: BackendContext = this.backendContext,
irMetadata: IrMetadata = this.irMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
)(
f: ExecuteContext => A
): A =
Expand All @@ -208,5 +213,6 @@ class ExecuteContext(
backendContext,
irMetadata,
blockMatrixCache,
codeCache,
))(f)
}
6 changes: 4 additions & 2 deletions hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,14 @@ object LocalBackend {
class LocalBackend(
val tmpdir: String,
override val references: mutable.Map[String, ReferenceGenome],
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
) extends Backend with Py4JBackendExtensions {

override def backend: Backend = this
override val flags: HailFeatureFlags = HailFeatureFlags.fromEnv()
override def longLifeTempFileManager: TempFileManager = null

private[this] val theHailClassLoader = new HailClassLoader(getClass().getClassLoader())
private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)

// flags can be set after construction from python
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
Expand All @@ -113,6 +114,7 @@ class LocalBackend(
},
new IrMetadata(),
ImmutableMap.empty,
codeCache,
)(f)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class ServiceBackendContext(
) extends BackendContext with Serializable {}

object ServiceBackend {
private val log = Logger.getLogger(getClass.getName())

def apply(
jarLocation: String,
Expand Down Expand Up @@ -130,8 +129,7 @@ class ServiceBackend(
val fs: FS,
val serviceBackendContext: ServiceBackendContext,
val scratchDir: String,
) extends Backend with BackendWithNoCodeCache {
import ServiceBackend.log
) extends Backend with Logging {

private[this] var stageCount = 0
private[this] val MAX_AVAILABLE_GCS_CONNECTIONS = 1000
Expand Down Expand Up @@ -393,6 +391,7 @@ class ServiceBackend(
serviceBackendContext,
new IrMetadata(),
ImmutableMap.empty,
mutable.Map.empty,
)(f)
}

Expand Down
8 changes: 5 additions & 3 deletions hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ class SparkBackend(
override val references: mutable.Map[String, ReferenceGenome],
gcsRequesterPaysProject: String,
gcsRequesterPaysBuckets: String,
) extends Backend with BackendWithCodeCache with Py4JBackendExtensions {
) extends Backend with Py4JBackendExtensions {

assert(gcsRequesterPaysProject != null || gcsRequesterPaysBuckets == null)
lazy val sparkSession: SparkSession = SparkSession.builder().config(sc.getConf).getOrCreate()
Expand Down Expand Up @@ -351,8 +351,8 @@ class SparkBackend(
override val longLifeTempFileManager: TempFileManager =
new OwningTempFileManager(fs)

private[this] val bmCache: BlockMatrixCache =
new BlockMatrixCache()
private[this] val bmCache = new BlockMatrixCache()
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)

def createExecuteContextForTests(
timer: ExecutionTimer,
Expand All @@ -376,6 +376,7 @@ class SparkBackend(
},
new IrMetadata(),
ImmutableMap.empty,
mutable.Map.empty,
)

override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
Expand All @@ -396,6 +397,7 @@ class SparkBackend(
},
new IrMetadata(),
bmCache,
codeCache,
)(f)
}

Expand Down
177 changes: 81 additions & 96 deletions hail/src/main/scala/is/hail/expr/ir/Compile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -44,53 +44,49 @@ object Compile {
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) =
ctx.time {
val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true)
val k =
CodeCacheKey(FastSeq[AggStateSig](), params.map { case (n, pt) => (n, pt) }, normalizedBody)
(ctx.backend.lookupOrCompileCachedFunction[F](k) {

var ir = body
ir = Subst(
ir,
BindingEnv(params
.zipWithIndex
.foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }),
)
ir =
LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx)

TypeCheck(ctx, ir, BindingEnv.empty)

val returnParam = CodeParamType(SingleCodeType.typeInfoFromType(ir.typ))

val fb = EmitFunctionBuilder[F](
ctx,
"Compiled",
CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) =>
pt
},
returnParam,
Some("Emit.scala"),
)

/* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${
* x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) }
* }
*
* visit(ir) } */

assert(
fb.mb.parameterTypeInfo == expectedCodeParamTypes,
s"expected $expectedCodeParamTypes, got ${fb.mb.parameterTypeInfo}",
)
assert(
fb.mb.returnTypeInfo == expectedCodeReturnType,
s"expected $expectedCodeReturnType, got ${fb.mb.returnTypeInfo}",
)

val emitContext = EmitContext.analyze(ctx, ir)
val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length)
CompiledFunction(rt, fb.resultWithIndex(print))
}).tuple
ctx.CodeCache.getOrElseUpdate(
CodeCacheKey(FastSeq(), params.map { case (n, pt) => (n, pt) }, normalizedBody), {
var ir = body
ir = Subst(
ir,
BindingEnv(params
.zipWithIndex
.foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }),
)
ir = LoweringPipeline.compileLowerer(optimize)(ctx, ir).asInstanceOf[IR].noSharing(ctx)

TypeCheck(ctx, ir)

val fb = EmitFunctionBuilder[F](
ctx,
"Compiled",
CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) =>
pt
},
CodeParamType(SingleCodeType.typeInfoFromType(ir.typ)),
Some("Emit.scala"),
)

/* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${
* x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c)
* } }
*
* visit(ir) } */

assert(
fb.mb.parameterTypeInfo == expectedCodeParamTypes,
s"expected $expectedCodeParamTypes, got ${fb.mb.parameterTypeInfo}",
)
assert(
fb.mb.returnTypeInfo == expectedCodeReturnType,
s"expected $expectedCodeReturnType, got ${fb.mb.returnTypeInfo}",
)

val emitContext = EmitContext.analyze(ctx, ir)
val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length)
CompiledFunction(rt, fb.resultWithIndex(print))
},
).asInstanceOf[CompiledFunction[F]].tuple
}
}

Expand All @@ -108,55 +104,44 @@ object CompileWithAggregators {
(HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion),
) =
ctx.time {
val normalizedBody =
NormalizeNames(ctx, body, allowFreeVariables = true)
val k = CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody)
(ctx.backend.lookupOrCompileCachedFunction[F with FunctionWithAggRegion](k) {

var ir = body
ir = Subst(
ir,
BindingEnv(params
.zipWithIndex
.foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }),
)
ir =
LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx)

TypeCheck(
ctx,
ir,
BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType })),
)

val fb = EmitFunctionBuilder[F](
ctx,
"CompiledWithAggs",
CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => pt },
SingleCodeType.typeInfoFromType(ir.typ),
Some("Emit.scala"),
)

/* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${
* x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c) }
* }
*
* visit(ir) } */

val emitContext = EmitContext.analyze(ctx, ir)
val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, Some(aggSigs))

val f = fb.resultWithIndex()
CompiledFunction(
rt,
f.asInstanceOf[(
HailClassLoader,
FS,
HailTaskContext,
Region,
) => (F with FunctionWithAggRegion)],
)
}).tuple
val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true)
ctx.CodeCache.getOrElseUpdate(
CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody), {
var ir = body
ir = Subst(
ir,
BindingEnv(params
.zipWithIndex
.foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }),
)
ir =
LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx)

TypeCheck(
ctx,
ir,
BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType })),
)

val fb = EmitFunctionBuilder[F with FunctionWithAggRegion](
ctx,
"CompiledWithAggs",
CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => pt },
SingleCodeType.typeInfoFromType(ir.typ),
Some("Emit.scala"),
)

/* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${
* x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c)
* } }
*
* visit(ir) } */

val emitContext = EmitContext.analyze(ctx, ir)
val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, Some(aggSigs))
CompiledFunction(rt, fb.resultWithIndex())
},
).asInstanceOf[CompiledFunction[F with FunctionWithAggRegion]].tuple
}
}

Expand Down
Loading

0 comments on commit a99150b

Please sign in to comment.