Skip to content

Commit

Permalink
[query] Remove persistedIr from Backend interface
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Jan 22, 2025
1 parent a07503f commit 9aecc1d
Show file tree
Hide file tree
Showing 12 changed files with 466 additions and 499 deletions.
30 changes: 13 additions & 17 deletions hail/src/main/scala/is/hail/backend/Backend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@ package is.hail.backend
import is.hail.asm4s._
import is.hail.backend.Backend.jsonToBytes
import is.hail.backend.spark.SparkBackend
import is.hail.expr.ir.{
BaseIR, IR, IRParser, IRParserEnvironment, LoweringAnalyses, SortField, TableIR, TableReader,
}
import is.hail.expr.ir.{IR, IRParser, LoweringAnalyses, SortField, TableIR, TableReader}
import is.hail.expr.ir.lowering.{TableStage, TableStageDependency}
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
Expand All @@ -18,7 +16,6 @@ import is.hail.types.virtual.TFloat64
import is.hail.utils._
import is.hail.variant.ReferenceGenome

import scala.collection.mutable
import scala.reflect.ClassTag

import java.io._
Expand Down Expand Up @@ -67,7 +64,6 @@ trait BackendContext {
}

abstract class Backend extends Closeable {
val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()

def defaultParallelism: Int

Expand Down Expand Up @@ -125,30 +121,30 @@ abstract class Backend extends Closeable {
def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T

final def valueType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext { ctx =>
IRParser.parse_value_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
withExecuteContext { ctx =>
jsonToBytes {
IRParser.parse_value_ir(ctx, s).typ.toJSON
}
}

final def tableType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext { ctx =>
IRParser.parse_table_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
withExecuteContext { ctx =>
jsonToBytes {
IRParser.parse_table_ir(ctx, s).typ.toJSON
}
}

final def matrixTableType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext { ctx =>
IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
withExecuteContext { ctx =>
jsonToBytes {
IRParser.parse_matrix_ir(ctx, s).typ.toJSON
}
}

final def blockMatrixType(s: String): Array[Byte] =
jsonToBytes {
withExecuteContext { ctx =>
IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, persistedIR.toMap)).typ.toJSON
withExecuteContext { ctx =>
jsonToBytes {
IRParser.parse_blockmatrix_ir(ctx, s).typ.toJSON
}
}

Expand Down
7 changes: 2 additions & 5 deletions hail/src/main/scala/is/hail/backend/BackendServer.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package is.hail.backend

import is.hail.expr.ir.{IRParser, IRParserEnvironment}
import is.hail.expr.ir.IRParser
import is.hail.utils._

import scala.util.control.NonFatal
Expand Down Expand Up @@ -89,10 +89,7 @@ class BackendHttpHandler(backend: Backend) extends HttpHandler {
backend.withExecuteContext { ctx =>
val (res, timings) = ExecutionTimer.time { timer =>
ctx.local(timer = timer) { ctx =>
val irData = IRParser.parse_value_ir(
irStr,
IRParserEnvironment(ctx, irMap = backend.persistedIR.toMap),
)
val irData = IRParser.parse_value_ir(ctx, irStr)
backend.execute(ctx, irData)
}
}
Expand Down
7 changes: 6 additions & 1 deletion hail/src/main/scala/is/hail/backend/ExecuteContext.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import is.hail.{HailContext, HailFeatureFlags}
import is.hail.annotations.{Region, RegionPool}
import is.hail.asm4s.HailClassLoader
import is.hail.backend.local.LocalTaskContext
import is.hail.expr.ir.{CodeCacheKey, CompiledFunction}
import is.hail.expr.ir.{BaseIR, CodeCacheKey, CompiledFunction}
import is.hail.expr.ir.lowering.IrMetadata
import is.hail.io.fs.FS
import is.hail.linalg.BlockMatrix
Expand Down Expand Up @@ -75,6 +75,7 @@ object ExecuteContext {
irMetadata: IrMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix],
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
irCache: mutable.Map[Int, BaseIR],
)(
f: ExecuteContext => T
): T = {
Expand All @@ -95,6 +96,7 @@ object ExecuteContext {
irMetadata,
blockMatrixCache,
codeCache,
irCache,
))(f(_))
}
}
Expand Down Expand Up @@ -126,6 +128,7 @@ class ExecuteContext(
val irMetadata: IrMetadata,
val BlockMatrixCache: mutable.Map[String, BlockMatrix],
val CodeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]],
val IrCache: mutable.Map[Int, BaseIR],
) extends Closeable {

val rngNonce: Long =
Expand Down Expand Up @@ -194,6 +197,7 @@ class ExecuteContext(
irMetadata: IrMetadata = this.irMetadata,
blockMatrixCache: mutable.Map[String, BlockMatrix] = this.BlockMatrixCache,
codeCache: mutable.Map[CodeCacheKey, CompiledFunction[_]] = this.CodeCache,
irCache: mutable.Map[Int, BaseIR] = this.IrCache,
)(
f: ExecuteContext => A
): A =
Expand All @@ -212,5 +216,6 @@ class ExecuteContext(
irMetadata,
blockMatrixCache,
codeCache,
irCache,
))(f)
}
2 changes: 2 additions & 0 deletions hail/src/main/scala/is/hail/backend/local/LocalBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ class LocalBackend(

private[this] val theHailClassLoader = new HailClassLoader(getClass.getClassLoader)
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
private[this] val persistedIR: mutable.Map[Int, BaseIR] = mutable.Map()

// flags can be set after construction from python
def fs: FS = RouterFS.buildRoutes(CloudStorageFSConfig.fromFlagsAndEnv(None, flags))
Expand All @@ -116,6 +117,7 @@ class LocalBackend(
new IrMetadata(),
ImmutableMap.empty,
codeCache,
persistedIR,
)(f)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ import is.hail.HailFeatureFlags
import is.hail.backend.{Backend, ExecuteContext, NonOwningTempFileManager, TempFileManager}
import is.hail.expr.{JSONAnnotationImpex, SparkAnnotationImpex}
import is.hail.expr.ir.{
BaseIR, BindingEnv, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IR, IRParser,
IRParserEnvironment, Interpret, MatrixIR, MatrixNativeReader, MatrixRead, Name,
NativeReaderOptions, TableIR, TableLiteral, TableValue,
BaseIR, BindingEnv, BlockMatrixIR, EncodedLiteral, GetFieldByIdx, IR, IRParser, Interpret,
MatrixIR, MatrixNativeReader, MatrixRead, Name, NativeReaderOptions, TableIR, TableLiteral,
TableValue,
}
import is.hail.expr.ir.IRParser.parseType
import is.hail.expr.ir.functions.IRFunctionRegistry
Expand Down Expand Up @@ -34,7 +34,6 @@ import sourcecode.Enclosing
trait Py4JBackendExtensions {
def backend: Backend
def references: mutable.Map[String, ReferenceGenome]
def persistedIR: mutable.Map[Int, BaseIR]
def flags: HailFeatureFlags
def longLifeTempFileManager: TempFileManager

Expand All @@ -54,14 +53,14 @@ trait Py4JBackendExtensions {
irID
}

private[this] def addJavaIR(ir: BaseIR): Int = {
private[this] def addJavaIR(ctx: ExecuteContext, ir: BaseIR): Int = {
val id = nextIRID()
persistedIR += (id -> ir)
ctx.IrCache += (id -> ir)
id
}

def pyRemoveJavaIR(id: Int): Unit =
persistedIR.remove(id)
backend.withExecuteContext(_.IrCache.remove(id))

def pyAddSequence(name: String, fastaFile: String, indexFile: String): Unit =
backend.withExecuteContext { ctx =>
Expand Down Expand Up @@ -118,7 +117,7 @@ trait Py4JBackendExtensions {
argTypeStrs: java.util.ArrayList[String],
returnType: String,
bodyStr: String,
): Unit = {
): Unit =
backend.withExecuteContext { ctx =>
IRFunctionRegistry.registerIR(
ctx,
Expand All @@ -130,17 +129,16 @@ trait Py4JBackendExtensions {
bodyStr,
)
}
}

def pyExecuteLiteral(irStr: String): Int =
backend.withExecuteContext { ctx =>
val ir = IRParser.parse_value_ir(irStr, IRParserEnvironment(ctx, persistedIR.toMap))
val ir = IRParser.parse_value_ir(ctx, irStr)
assert(ir.typ.isRealizable)
backend.execute(ctx, ir) match {
case Left(_) => throw new HailException("Can't create literal")
case Right((pt, addr)) =>
val field = GetFieldByIdx(EncodedLiteral.fromPTypeAndAddress(pt, addr, ctx), 0)
addJavaIR(field)
addJavaIR(ctx, field)
}
}

Expand All @@ -159,14 +157,14 @@ trait Py4JBackendExtensions {
),
ctx.theHailClassLoader,
)
val id = addJavaIR(tir)
val id = addJavaIR(ctx, tir)
(id, JsonMethods.compact(tir.typ.toJSON))
}
}

def pyToDF(s: String): DataFrame =
backend.withExecuteContext { ctx =>
val tir = IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
val tir = IRParser.parse_table_ir(ctx, s)
Interpret(tir, ctx).toDF()
}

Expand Down Expand Up @@ -219,27 +217,23 @@ trait Py4JBackendExtensions {
def parse_value_ir(s: String, refMap: java.util.Map[String, String]): IR =
backend.withExecuteContext { ctx =>
IRParser.parse_value_ir(
ctx,
s,
IRParserEnvironment(ctx, irMap = persistedIR.toMap),
BindingEnv.eval(refMap.asScala.toMap.map { case (n, t) =>
Name(n) -> IRParser.parseType(t)
}.toSeq: _*),
)
}

def parse_table_ir(s: String): TableIR =
withExecuteContext(selfContainedExecution = false) { ctx =>
IRParser.parse_table_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
}
withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_table_ir(ctx, s))

def parse_matrix_ir(s: String): MatrixIR =
withExecuteContext(selfContainedExecution = false) { ctx =>
IRParser.parse_matrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
}
withExecuteContext(selfContainedExecution = false)(ctx => IRParser.parse_matrix_ir(ctx, s))

def parse_blockmatrix_ir(s: String): BlockMatrixIR =
withExecuteContext(selfContainedExecution = false) { ctx =>
IRParser.parse_blockmatrix_ir(s, IRParserEnvironment(ctx, irMap = persistedIR.toMap))
IRParser.parse_blockmatrix_ir(ctx, s)
}

def loadReferencesFromDataset(path: String): Array[Byte] =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ class ServiceBackend(
new IrMetadata(),
ImmutableMap.empty,
mutable.Map.empty,
ImmutableMap.empty,
)(f)
}

Expand Down
3 changes: 3 additions & 0 deletions hail/src/main/scala/is/hail/backend/spark/SparkBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ class SparkBackend(

private[this] val bmCache = new BlockMatrixCache()
private[this] val codeCache = new Cache[CodeCacheKey, CompiledFunction[_]](50)
private[this] val persistedIr = mutable.Map.empty[Int, BaseIR]

def createExecuteContextForTests(
timer: ExecutionTimer,
Expand All @@ -378,6 +379,7 @@ class SparkBackend(
new IrMetadata(),
ImmutableMap.empty,
mutable.Map.empty,
ImmutableMap.empty,
)

override def withExecuteContext[T](f: ExecuteContext => T)(implicit E: Enclosing): T =
Expand All @@ -399,6 +401,7 @@ class SparkBackend(
new IrMetadata(),
bmCache,
codeCache,
persistedIr,
)(f)
}

Expand Down
12 changes: 6 additions & 6 deletions hail/src/main/scala/is/hail/expr/ir/MatrixIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,14 @@ case class MatrixLiteral(typ: MatrixType, tl: TableLiteral) extends MatrixIR {
}

object MatrixReader {
def fromJson(env: IRParserEnvironment, jv: JValue): MatrixReader = {
def fromJson(ctx: ExecuteContext, jv: JValue): MatrixReader = {
implicit val formats: Formats = DefaultFormats
(jv \ "name").extract[String] match {
case "MatrixRangeReader" => MatrixRangeReader.fromJValue(env.ctx, jv)
case "MatrixNativeReader" => MatrixNativeReader.fromJValue(env.ctx.fs, jv)
case "MatrixBGENReader" => MatrixBGENReader.fromJValue(env, jv)
case "MatrixPLINKReader" => MatrixPLINKReader.fromJValue(env.ctx, jv)
case "MatrixVCFReader" => MatrixVCFReader.fromJValue(env.ctx, jv)
case "MatrixRangeReader" => MatrixRangeReader.fromJValue(ctx, jv)
case "MatrixNativeReader" => MatrixNativeReader.fromJValue(ctx.fs, jv)
case "MatrixBGENReader" => MatrixBGENReader.fromJValue(ctx, jv)
case "MatrixPLINKReader" => MatrixPLINKReader.fromJValue(ctx, jv)
case "MatrixVCFReader" => MatrixVCFReader.fromJValue(ctx, jv)
}
}

Expand Down
Loading

0 comments on commit 9aecc1d

Please sign in to comment.