Skip to content

Commit

Permalink
only remove ir functions registered in that execute request
Browse files Browse the repository at this point in the history
  • Loading branch information
ehigham committed Dec 12, 2024
1 parent 3544a86 commit 59ad7bb
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 29 deletions.
10 changes: 4 additions & 6 deletions hail/python/hail/backend/service_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ class SequenceConfig:
@dataclass
class ServiceBackendRPCConfig:
tmp_dir: str
remote_tmpdir: str
flags: Dict[str, str]
custom_references: List[str]
liftovers: Dict[str, Dict[str, str]]
Expand Down Expand Up @@ -328,8 +327,8 @@ async def _run_on_batch(
elif self.driver_memory is not None:
resources['memory'] = str(self.driver_memory)

if service_backend_config.storage != '0Gi':
resources['storage'] = service_backend_config.storage
if job_config.storage != '0Gi':
resources['storage'] = job_config.storage

j = self._batch.create_jvm_job(
jar_spec=self.jar_spec,
Expand All @@ -343,7 +342,7 @@ async def _run_on_batch(
resources=resources,
attributes={'name': name + '_driver'},
regions=self.regions,
cloudfuse=[(c.bucket, c.mount_path, c.read_only) for c in service_backend_config.cloudfuse_configs],
cloudfuse=[(c.bucket, c.mount_path, c.read_only) for c in job_config.cloudfuse_configs],
profile=self.flags['profile'] is not None,
)
await self._batch.submit(disable_progress_bar=True)
Expand Down Expand Up @@ -441,8 +440,7 @@ async def _async_rpc(self, action: ActionTag, payload: ActionPayload):
return await self._run_on_batch(
name=f'{action.name.lower()}(...)',
service_backend_config=ServiceBackendRPCConfig(
tmp_dir=tmp_dir(),
remote_tmpdir=self.remote_tmpdir,
tmp_dir=self.remote_tmpdir,
flags=self.flags,
custom_references=[
orjson.dumps(rg._config).decode('utf-8')
Expand Down
11 changes: 7 additions & 4 deletions hail/src/main/scala/is/hail/backend/BackendRpc.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ package is.hail.backend

import is.hail.expr.ir.IRParser
import is.hail.expr.ir.functions.IRFunctionRegistry
import is.hail.expr.ir.functions.IRFunctionRegistry.UserDefinedFnKey
import is.hail.io.BufferSpec
import is.hail.io.plink.LoadPlink
import is.hail.io.vcf.LoadVCF
import is.hail.services.retryTransientErrors
import is.hail.types.virtual.{Kind, TFloat64, VType}
import is.hail.types.virtual.Kinds._
import is.hail.utils.{using, ExecutionTimer}
import is.hail.utils.{using, BoxedArrayBuilder, ExecutionTimer}
import is.hail.utils.ExecutionTimer.Timings
import is.hail.variant.ReferenceGenome

Expand Down Expand Up @@ -177,9 +178,10 @@ trait BackendRpc {
)(
body: => A
): A = {
val fns = new BoxedArrayBuilder[UserDefinedFnKey](serializedFunctions.length)
try {
serializedFunctions.foreach { func =>
IRFunctionRegistry.registerIR(
for (func <- serializedFunctions) {
fns += IRFunctionRegistry.registerIR(
ctx,
func.name,
func.type_parameters,
Expand All @@ -192,7 +194,8 @@ trait BackendRpc {

body
} finally
IRFunctionRegistry.clearUserFunctions()
for (i <- 0 until fns.length)
IRFunctionRegistry.unregisterIr(fns(i))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -308,17 +308,17 @@ class ServiceBackend(
ExecutionTimer.time { timer =>
ExecuteContext.scoped(
rpcConfig.tmp_dir,
rpcConfig.remote_tmpdir,
rpcConfig.tmp_dir,
this,
fs,
timer,
null,
theHailClassLoader,
flags,
ServiceBackendContext(
rpcConfig.remote_tmpdir,
rpcConfig.tmp_dir,
jobConfig,
ExecutionCache.fromFlags(flags, fs, rpcConfig.remote_tmpdir),
ExecutionCache.fromFlags(flags, fs, rpcConfig.tmp_dir),
),
new IrMetadata(),
references,
Expand Down Expand Up @@ -536,7 +536,6 @@ case class SequenceConfig(fasta: String, index: String)

case class ServiceBackendRPCPayload(
tmp_dir: String,
remote_tmpdir: String,
flags: Map[String, String],
custom_references: Array[String],
liftovers: Map[String, Map[String, String]],
Expand Down
58 changes: 43 additions & 15 deletions hail/src/main/scala/is/hail/expr/ir/functions/Functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ import scala.reflect._
import org.apache.spark.sql.Row

object IRFunctionRegistry {
private val userAddedFunctions: mutable.Set[(String, (Type, Seq[Type], Seq[Type]))] =
type UserDefinedFnKey = (String, (Type, Seq[Type], Seq[Type]))

private[this] val userAddedFunctions: mutable.Set[UserDefinedFnKey] =
mutable.HashSet.empty

def clearUserFunctions(): Unit = {
Expand Down Expand Up @@ -69,25 +71,41 @@ object IRFunctionRegistry {
typeParamStrs: Array[String],
argNameStrs: Array[String],
argTypeStrs: Array[String],
returnType: String,
returnTypeStr: String,
bodyStr: String,
): Unit = {
): UserDefinedFnKey = {
requireJavaIdentifier(name)
val argNames = argNameStrs.map(Name)
val typeParameters = typeParamStrs.map(IRParser.parseType).toFastSeq
val valueParameterTypes = argTypeStrs.map(IRParser.parseType).toFastSeq
val refMap = BindingEnv.eval(argNames.zip(valueParameterTypes): _*)
val body = IRParser.parse_value_ir(ctx, bodyStr, refMap)
val argNames = argNameStrs.map(Name)

val body =
IRParser.parse_value_ir(ctx, bodyStr, BindingEnv.eval(argNames.zip(valueParameterTypes): _*))
val returnType = IRParser.parseType(returnTypeStr)
assert(body.typ == returnType)

userAddedFunctions += ((name, (body.typ, typeParameters, valueParameterTypes)))
val key: UserDefinedFnKey = (name, (returnType, typeParameters, valueParameterTypes))
userAddedFunctions += key
addIR(
name,
typeParameters,
valueParameterTypes,
IRParser.parseType(returnType),
returnType,
false,
(_, args, _) => Subst(body, BindingEnv.eval(argNames.zip(args): _*)),
)
key
}

def unregisterIr(key: UserDefinedFnKey): Unit = {
val (name, (returnType, typeParameterTypes, valueParameterTypes)) = key
if (userAddedFunctions.remove(key))
removeIRFunction(name, returnType, typeParameterTypes, valueParameterTypes)
else {
throw new NoSuchElementException(
s"No user defined function registered matching: ${prettyFunctionSignature(name, returnType, typeParameterTypes, valueParameterTypes)}"
)
}
}

def removeIRFunction(
Expand All @@ -112,7 +130,9 @@ object IRFunctionRegistry {
case Seq() => None
case Seq(f) => Some(f)
case _ =>
fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).")
fatal(
s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}."
)
}

def lookupFunctionOrFail(
Expand All @@ -124,28 +144,34 @@ object IRFunctionRegistry {
jvmRegistry.lift(name) match {
case None =>
fatal(
s"no functions found with the signature $name(${valueParameterTypes.mkString(", ")}): $returnType"
s"no functions found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}."
)
case Some(functions) =>
functions.filter(t =>
t.unify(typeParameters, valueParameterTypes, returnType)
).toSeq match {
case Seq() =>
val prettyFunctionSignature =
s"$name[${typeParameters.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType"
val prettyMismatchedFunctionSignatures = functions.map(x => s" $x").mkString("\n")
fatal(
s"No function found with the signature $prettyFunctionSignature.\n" +
s"No function found with the signature ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}.\n" +
s"However, there are other functions with that name:\n$prettyMismatchedFunctionSignatures"
)
case Seq(f) => f
case _ => fatal(
s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(", ")})."
s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)})."
)
}
}
}

private[this] def prettyFunctionSignature(
name: String,
returnType: Type,
typeParameterTypes: Seq[Type],
valueParameterTypes: Seq[Type],
): String =
s"$name[${typeParameterTypes.mkString(", ")}](${valueParameterTypes.mkString(", ")}): $returnType"

def lookupIR(
name: String,
returnType: Type,
Expand All @@ -165,7 +191,9 @@ object IRFunctionRegistry {
case Seq() => None
case Seq(kv) => Some(kv)
case _ =>
fatal(s"Multiple functions found that satisfy $name(${valueParameterTypes.mkString(",")}).")
fatal(
s"Multiple functions found that satisfy ${prettyFunctionSignature(name, returnType, typeParameters, valueParameterTypes)}."
)
}
}

Expand Down

0 comments on commit 59ad7bb

Please sign in to comment.