Skip to content

Commit

Permalink
reduce code duplication
Browse files (browse the repository at this point in the history)
  • Loading branch information
ehigham committed Dec 17, 2024
1 parent a99150b commit b4ad140
Show file tree
Hide file tree
Showing 15 changed files with 94 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
import is.hail.types._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@ import is.hail.asm4s._
import is.hail.backend._
import is.hail.expr.Validate
import is.hail.expr.ir.{
Compile, IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader,
TypeCheck,
IR, IRParser, IRSize, LoweringAnalyses, MakeTuple, SortField, TableIR, TableReader, TypeCheck,
}
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.lowering._
import is.hail.io.fs._
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import is.hail.backend.py4j.Py4JBackendExtensions
import is.hail.expr.Validate
import is.hail.expr.ir._
import is.hail.expr.ir.analyses.SemanticHash
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.lowering._
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.io.fs._
Expand Down
133 changes: 61 additions & 72 deletions hail/src/main/scala/is/hail/expr/ir/Compile.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@ import is.hail.types.physical.stypes.{
PTypeReferenceSingleCodeType, SingleCodeType, StreamSingleCodeType,
}
import is.hail.types.physical.stypes.interfaces.{NoBoxLongIterator, SStream}
import is.hail.types.virtual.Type
import is.hail.utils._

import java.io.PrintWriter

import sourcecode.Enclosing

case class CodeCacheKey(
aggSigs: IndexedSeq[AggStateSig],
args: Seq[(Name, EmitParamType)],
Expand All @@ -32,8 +33,9 @@ case class CompiledFunction[T](
(typ, f)
}

object Compile {
def apply[F: TypeInfo](
object compile {

def Compile[F: TypeInfo](
ctx: ExecuteContext,
params: IndexedSeq[(Name, EmitParamType)],
expectedCodeParamTypes: IndexedSeq[TypeInfo[_]],
Expand All @@ -42,27 +44,69 @@ object Compile {
optimize: Boolean = true,
print: Option[PrintWriter] = None,
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F) =
Impl[F, AnyVal](
ctx,
params,
None,
expectedCodeParamTypes,
expectedCodeReturnType,
body,
optimize,
print,
)

/** Compiles `body` into a factory for functions of class `F` that additionally mix in
 * `FunctionWithAggRegion`.
 *
 * This is a thin wrapper: every argument is forwarded unchanged to `Impl`, with `aggSigs`
 * wrapped in `Some` and the mixin type parameter fixed to `FunctionWithAggRegion`.
 *
 * @tparam F
 *   the function class to generate; a `TypeInfo[F]` must be available implicitly
 * @param ctx
 *   the execution context handed to `Impl`
 * @param aggSigs
 *   the aggregator state signatures; passed to `Impl` as `Some(aggSigs)`
 * @param params
 *   the named parameters of the compiled function, each paired with its emit parameter type
 * @param expectedCodeParamTypes
 *   forwarded to `Impl` as-is
 * @param expectedCodeReturnType
 *   forwarded to `Impl` as-is
 * @param body
 *   the IR to compile
 * @param optimize
 *   forwarded to `Impl` as-is; defaults to `true`
 * @param print
 *   forwarded to `Impl` as-is; defaults to `None`
 * @return
 *   the pair produced by `Impl`: an optional `SingleCodeType` and a function from
 *   `(HailClassLoader, FS, HailTaskContext, Region)` to an `F with FunctionWithAggRegion`
 */
def CompileWithAggregators[F: TypeInfo](
ctx: ExecuteContext,
aggSigs: Array[AggStateSig],
params: IndexedSeq[(Name, EmitParamType)],
expectedCodeParamTypes: IndexedSeq[TypeInfo[_]],
expectedCodeReturnType: TypeInfo[_],
body: IR,
optimize: Boolean = true,
print: Option[PrintWriter] = None,
): (
Option[SingleCodeType],
(HailClassLoader, FS, HailTaskContext, Region) => F with FunctionWithAggRegion,
) =
// NOTE(review): `Impl` also takes implicit `Enclosing` and `sourcecode.Name` values, which are
// resolved here at the call site; presumably that is how the generated function gets a name
// derived from this method — TODO confirm against the sourcecode macro semantics.
Impl[F, FunctionWithAggRegion](
ctx,
params,
Some(aggSigs),
expectedCodeParamTypes,
expectedCodeReturnType,
body,
optimize,
print,
)

private[this] def Impl[F: TypeInfo, Mixin](
ctx: ExecuteContext,
params: IndexedSeq[(Name, EmitParamType)],
aggSigs: Option[Array[AggStateSig]],
expectedCodeParamTypes: IndexedSeq[TypeInfo[_]],
expectedCodeReturnType: TypeInfo[_],
body: IR,
optimize: Boolean,
print: Option[PrintWriter],
)(implicit
E: Enclosing,
N: sourcecode.Name,
): (Option[SingleCodeType], (HailClassLoader, FS, HailTaskContext, Region) => F with Mixin) =
ctx.time {
val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true)
ctx.CodeCache.getOrElseUpdate(
CodeCacheKey(FastSeq(), params.map { case (n, pt) => (n, pt) }, normalizedBody), {
var ir = body
ir = Subst(
ir,
BindingEnv(params
.zipWithIndex
.foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }),
CodeCacheKey(aggSigs.getOrElse(Array.empty).toFastSeq, params, normalizedBody), {
var ir = Subst(
body,
BindingEnv(Env.fromSeq(params.zipWithIndex.map { case ((n, t), i) => n -> In(i, t) })),
)
ir = LoweringPipeline.compileLowerer(optimize)(ctx, ir).asInstanceOf[IR].noSharing(ctx)

TypeCheck(ctx, ir)

val fb = EmitFunctionBuilder[F](
ctx,
"Compiled",
CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) =>
pt
},
N.value,
CodeParamType(typeInfo[Region]) +: params.map(_._2),
CodeParamType(SingleCodeType.typeInfoFromType(ir.typ)),
Some("Emit.scala"),
)
Expand All @@ -83,65 +127,10 @@ object Compile {
)

val emitContext = EmitContext.analyze(ctx, ir)
val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length)
val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, aggSigs)
CompiledFunction(rt, fb.resultWithIndex(print))
},
).asInstanceOf[CompiledFunction[F]].tuple
}
}

object CompileWithAggregators {
def apply[F: TypeInfo](
ctx: ExecuteContext,
aggSigs: Array[AggStateSig],
params: IndexedSeq[(Name, EmitParamType)],
expectedCodeParamTypes: IndexedSeq[TypeInfo[_]],
expectedCodeReturnType: TypeInfo[_],
body: IR,
optimize: Boolean = true,
): (
Option[SingleCodeType],
(HailClassLoader, FS, HailTaskContext, Region) => (F with FunctionWithAggRegion),
) =
ctx.time {
val normalizedBody = NormalizeNames(ctx, body, allowFreeVariables = true)
ctx.CodeCache.getOrElseUpdate(
CodeCacheKey(aggSigs, params.map { case (n, pt) => (n, pt) }, normalizedBody), {
var ir = body
ir = Subst(
ir,
BindingEnv(params
.zipWithIndex
.foldLeft(Env.empty[IR]) { case (e, ((n, t), i)) => e.bind(n, In(i, t)) }),
)
ir =
LoweringPipeline.compileLowerer(optimize).apply(ctx, ir).asInstanceOf[IR].noSharing(ctx)

TypeCheck(
ctx,
ir,
BindingEnv(Env.fromSeq[Type](params.map { case (name, t) => name -> t.virtualType })),
)

val fb = EmitFunctionBuilder[F with FunctionWithAggRegion](
ctx,
"CompiledWithAggs",
CodeParamType(typeInfo[Region]) +: params.map { case (_, pt) => pt },
SingleCodeType.typeInfoFromType(ir.typ),
Some("Emit.scala"),
)

/* { def visit(x: IR): Unit = { println(f"${ System.identityHashCode(x) }%08x ${
* x.getClass.getSimpleName } ${ x.pType }") Children(x).foreach { case c: IR => visit(c)
* } }
*
* visit(ir) } */

val emitContext = EmitContext.analyze(ctx, ir)
val rt = Emit(emitContext, ir, fb, expectedCodeReturnType, params.length, Some(aggSigs))
CompiledFunction(rt, fb.resultWithIndex())
},
).asInstanceOf[CompiledFunction[F with FunctionWithAggRegion]].tuple
).asInstanceOf[CompiledFunction[F with Mixin]].tuple
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package is.hail.expr.ir
import is.hail.annotations.{Region, SafeRow}
import is.hail.asm4s._
import is.hail.backend.ExecuteContext
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.lowering.LoweringPipeline
import is.hail.types.physical.PTuple
import is.hail.types.physical.stypes.PTypeReferenceSingleCodeType
Expand Down
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/expr/ir/Emit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import is.hail.expr.ir.agg.{AggStateSig, ArrayAggStateSig, GroupedStateSig}
import is.hail.expr.ir.analyses.{
ComputeMethodSplits, ControlFlowPreventsSplit, ParentPointers, SemanticHash,
}
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.lowering.TableStageDependency
import is.hail.expr.ir.ndarrays.EmitNDArray
import is.hail.expr.ir.streams.{EmitStream, StreamProducer, StreamUtils}
Expand Down
1 change: 1 addition & 0 deletions hail/src/main/scala/is/hail/expr/ir/Interpret.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import is.hail.annotations._
import is.hail.asm4s._
import is.hail.backend.{ExecuteContext, HailTaskContext}
import is.hail.backend.spark.SparkTaskContext
import is.hail.expr.ir.compile.{Compile, CompileWithAggregators}
import is.hail.expr.ir.lowering.LoweringPipeline
import is.hail.io.BufferSpec
import is.hail.linalg.BlockMatrix
Expand Down
28 changes: 14 additions & 14 deletions hail/src/main/scala/is/hail/expr/ir/TableIR.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import is.hail.annotations._
import is.hail.asm4s._
import is.hail.backend.{ExecuteContext, HailStateManager, HailTaskContext, TaskFinalizer}
import is.hail.backend.spark.{SparkBackend, SparkTaskContext}
import is.hail.expr.ir
import is.hail.expr.ir.compile.{Compile, CompileWithAggregators}
import is.hail.expr.ir.functions.{
BlockMatrixToTableFunction, IntervalFunctions, MatrixToTableFunction, TableToTableFunction,
}
Expand Down Expand Up @@ -1931,7 +1931,7 @@ case class TableNativeZippedReader(
val leftRef = Ref(freshName(), pLeft.virtualType)
val rightRef = Ref(freshName(), pRight.virtualType)
val (Some(PTypeReferenceSingleCodeType(t: PStruct)), mk) =
ir.Compile[AsmFunction3RegionLongLongLong](
Compile[AsmFunction3RegionLongLongLong](
ctx,
FastSeq(
leftRef.name -> SingleCodeEmitParamType(true, PTypeReferenceSingleCodeType(pLeft)),
Expand Down Expand Up @@ -2420,7 +2420,7 @@ case class TableFilter(child: TableIR, pred: IR) extends TableIR {
else if (pred == False())
return TableValueIntermediate(tv.copy(rvd = RVD.empty(ctx, typ.canonicalRVDType)))

val (Some(BooleanSingleCodeType), f) = ir.Compile[AsmFunction3RegionLongLongBoolean](
val (Some(BooleanSingleCodeType), f) = Compile[AsmFunction3RegionLongLongBoolean](
ctx,
FastSeq(
(
Expand Down Expand Up @@ -3035,7 +3035,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR {

if (extracted.aggs.isEmpty) {
val (Some(PTypeReferenceSingleCodeType(rTyp)), f) =
ir.Compile[AsmFunction3RegionLongLongLong](
Compile[AsmFunction3RegionLongLongLong](
ctx,
FastSeq(
(
Expand Down Expand Up @@ -3101,7 +3101,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR {
// 3. load in partition aggregations, comb op as necessary, serialize.
// 4. load in partStarts, calculate newRow based on those results.

val (_, initF) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit](
val (_, initF) = CompileWithAggregators[AsmFunction2RegionLongUnit](
ctx,
extracted.states,
FastSeq((
Expand All @@ -3115,7 +3115,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR {

val serializeF = extracted.serialize(ctx, spec)

val (_, eltSeqF) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit](
val (_, eltSeqF) = CompileWithAggregators[AsmFunction3RegionLongLongUnit](
ctx,
extracted.states,
FastSeq(
Expand All @@ -3138,7 +3138,7 @@ case class TableMapRows(child: TableIR, newRow: IR) extends TableIR {
val combOpFNeedsPool = extracted.combOpFSerializedFromRegionPool(ctx, spec)

val (Some(PTypeReferenceSingleCodeType(rTyp)), f) =
ir.CompileWithAggregators[AsmFunction3RegionLongLongLong](
CompileWithAggregators[AsmFunction3RegionLongLongLong](
ctx,
extracted.states,
FastSeq(
Expand Down Expand Up @@ -3697,7 +3697,7 @@ case class TableKeyByAndAggregate(

val localKeyType = keyType
val (Some(PTypeReferenceSingleCodeType(localKeyPType: PStruct)), makeKeyF) =
ir.Compile[AsmFunction3RegionLongLongLong](
Compile[AsmFunction3RegionLongLongLong](
ctx,
FastSeq(
(
Expand All @@ -3723,7 +3723,7 @@ case class TableKeyByAndAggregate(

val extracted = agg.Extract(expr, Requiredness(this, ctx))

val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit](
val (_, makeInit) = CompileWithAggregators[AsmFunction2RegionLongUnit](
ctx,
extracted.states,
FastSeq((
Expand All @@ -3735,7 +3735,7 @@ case class TableKeyByAndAggregate(
extracted.init,
)

val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit](
val (_, makeSeq) = CompileWithAggregators[AsmFunction3RegionLongLongUnit](
ctx,
extracted.states,
FastSeq(
Expand All @@ -3754,7 +3754,7 @@ case class TableKeyByAndAggregate(
)

val (Some(PTypeReferenceSingleCodeType(rTyp: PStruct)), makeAnnotate) =
ir.CompileWithAggregators[AsmFunction2RegionLongLong](
CompileWithAggregators[AsmFunction2RegionLongLong](
ctx,
extracted.states,
FastSeq((
Expand Down Expand Up @@ -3897,7 +3897,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR {

val extracted = agg.Extract(expr, Requiredness(this, ctx))

val (_, makeInit) = ir.CompileWithAggregators[AsmFunction2RegionLongUnit](
val (_, makeInit) = CompileWithAggregators[AsmFunction2RegionLongUnit](
ctx,
extracted.states,
FastSeq((
Expand All @@ -3909,7 +3909,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR {
extracted.init,
)

val (_, makeSeq) = ir.CompileWithAggregators[AsmFunction3RegionLongLongUnit](
val (_, makeSeq) = CompileWithAggregators[AsmFunction3RegionLongLongUnit](
ctx,
extracted.states,
FastSeq(
Expand All @@ -3933,7 +3933,7 @@ case class TableAggregateByKey(child: TableIR, expr: IR) extends TableIR {
val key = Ref(freshName(), keyType.virtualType)
val value = Ref(freshName(), valueIR.typ)
val (Some(PTypeReferenceSingleCodeType(rowType: PStruct)), makeRow) =
ir.CompileWithAggregators[AsmFunction3RegionLongLongLong](
CompileWithAggregators[AsmFunction3RegionLongLongLong](
ctx,
extracted.states,
FastSeq(
Expand Down
7 changes: 4 additions & 3 deletions hail/src/main/scala/is/hail/expr/ir/agg/Extract.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import is.hail.backend.{ExecuteContext, HailTaskContext}
import is.hail.backend.spark.SparkTaskContext
import is.hail.expr.ir
import is.hail.expr.ir._
import is.hail.expr.ir.compile.CompileWithAggregators
import is.hail.io.BufferSpec
import is.hail.types.{TypeWithRequiredness, VirtualTypeWithReq}
import is.hail.types.physical.stypes.EmitType
Expand Down Expand Up @@ -247,7 +248,7 @@ class Aggs(

def deserialize(ctx: ExecuteContext, spec: BufferSpec)
: ((HailClassLoader, HailTaskContext, Region, Array[Byte]) => Long) = {
val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit](
val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit](
ctx,
states,
FastSeq(),
Expand All @@ -268,7 +269,7 @@ class Aggs(

def serialize(ctx: ExecuteContext, spec: BufferSpec)
: (HailClassLoader, HailTaskContext, Region, Long) => Array[Byte] = {
val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit](
val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit](
ctx,
states,
FastSeq(),
Expand Down Expand Up @@ -305,7 +306,7 @@ class Aggs(
: (() => (RegionPool, HailClassLoader, HailTaskContext)) => (
(Array[Byte], Array[Byte]) => Array[Byte],
) = {
val (_, f) = ir.CompileWithAggregators[AsmFunction1RegionUnit](
val (_, f) = CompileWithAggregators[AsmFunction1RegionUnit](
ctx,
states ++ states,
FastSeq(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import is.hail.annotations.{Annotation, ExtendedOrdering, Region, SafeRow}
import is.hail.asm4s.{classInfo, AsmFunction1RegionLong, LongInfo}
import is.hail.backend.{ExecuteContext, HailStateManager}
import is.hail.expr.ir._
import is.hail.expr.ir.compile.Compile
import is.hail.expr.ir.functions.{ArrayFunctions, IRRandomness, UtilFunctions}
import is.hail.io.{BufferSpec, TypedCodecSpec}
import is.hail.rvd.RVDPartitioner
Expand Down
Loading

0 comments on commit b4ad140

Please sign in to comment.