From 3d22ae9101dc4d30fb5d422ffe18140b4a86deb3 Mon Sep 17 00:00:00 2001 From: patrick-schultz Date: Tue, 17 Dec 2024 11:16:26 -0500 Subject: [PATCH] ir gen mvp --- hail/.scalafmt.conf | 2 +- hail/build.mill | 17 + hail/hail/ir-gen/src/Main.scala | 899 +++++++++++ hail/hail/src/is/hail/expr/ir/IR.scala | 1370 +++++------------ hail/hail/src/is/hail/expr/ir/Parser.scala | 2 +- hail/hail/src/is/hail/expr/ir/TypeCheck.scala | 2 +- .../expr/ir/functions/ArrayFunctions.scala | 4 +- .../is/hail/expr/ir/functions/Functions.scala | 20 +- .../is/hail/expr/ir/FoldConstantsSuite.scala | 8 +- .../ir/{ => defs}/EncodedLiteralSuite.scala | 3 +- 10 files changed, 1345 insertions(+), 982 deletions(-) create mode 100644 hail/hail/ir-gen/src/Main.scala rename hail/hail/test/src/is/hail/expr/ir/{ => defs}/EncodedLiteralSuite.scala (88%) diff --git a/hail/.scalafmt.conf b/hail/.scalafmt.conf index 245e3f6548b..7d38a8da4fc 100644 --- a/hail/.scalafmt.conf +++ b/hail/.scalafmt.conf @@ -1,4 +1,4 @@ -version = 3.7.17 +version = 3.8.3 runner.dialect = scala212 diff --git a/hail/build.mill b/hail/build.mill index 6a14da4d4b3..324e6e591f8 100644 --- a/hail/build.mill +++ b/hail/build.mill @@ -175,6 +175,10 @@ object hail extends HailModule { outer => buildInfo(), ) + override def generatedSources: T[Seq[PathRef]] = Task { + Seq(`ir-gen`.generate()) + } + override def unmanagedClasspath: T[Agg[PathRef]] = Agg(shadedazure.assembly()) @@ -246,6 +250,19 @@ object hail extends HailModule { outer => PathRef(T.dest) } + object `ir-gen` extends HailModule { + def ivyDeps = Agg( + ivy"com.lihaoyi::mainargs:0.6.2", + ivy"com.lihaoyi::os-lib:0.10.7", + ivy"com.lihaoyi::sourcecode:0.4.2", + ) + + def generate: T[PathRef] = Task { + runner().run(Args("--path", T.dest).value) + PathRef(T.dest) + } + } + object memory extends JavaModule { // with CrossValue { override def zincIncrementalCompilation: T[Boolean] = false diff --git a/hail/hail/ir-gen/src/Main.scala b/hail/hail/ir-gen/src/Main.scala new file mode 100644 index 00000000000..ff143a0d96d --- /dev/null +++ b/hail/hail/ir-gen/src/Main.scala @@ -0,0 +1,899 @@ +import mainargs.{ParserForMethods, main} + +sealed abstract class Trait(val name: String) + +object Trivial extends Trait("TrivialIR") +object BaseRef extends Trait("BaseRef") +final case class Typed(typ: String) extends Trait(s"TypedIR[$typ]") +object NDArray extends Trait("NDArrayIR") + +final case class Apply(missingnessAware: Boolean = false) + extends Trait( + s"AbstractApplyNode[UnseededMissingness${if (missingnessAware) "Aware" else "Oblivious"}JVMFunction]" + ) + +case class NChildren(static: Int = 0, dynamic: String = "") { + def +(other: NChildren): NChildren = NChildren( + static = static + other.static, + dynamic = if (dynamic.isEmpty) other.dynamic else s"$dynamic + ${other.dynamic}", + ) +} + +sealed abstract class ChildrenSeq { + def asDyn: ChildrenSeq.Dynamic + + def ++(other: ChildrenSeq): ChildrenSeq = (this, other) match { + case (ChildrenSeq.Static(Seq()), r) => r + case (l, ChildrenSeq.Static(Seq())) => l + case (ChildrenSeq.Static(l), ChildrenSeq.Static(r)) => ChildrenSeq.Static(l ++ r) + case _ => ChildrenSeq.Dynamic(this.asDyn.children + " ++ " + other.asDyn.children) + } +} + +object ChildrenSeq { + val empty: Static = Static(Seq.empty) + + final case class Static(children: Seq[String]) extends ChildrenSeq { + def asDyn: Dynamic = Dynamic(s"FastSeq(${children.mkString(", ")})") + + override def toString: String = ??? 
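+ // deliberately unimplemented: a Static children list is rendered through asDyn, never via toString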
+ } + + final case class Dynamic(children: String) extends ChildrenSeq { + def asDyn: Dynamic = this + } +} + +final case class NamedAttOrChildPack( + name: String, + pack: AttOrChildPack, + isVar: Boolean = false, + default: Option[String] = None, +) { + def mutable: NamedAttOrChildPack = copy(isVar = true) + def withDefault(value: String): NamedAttOrChildPack = copy(default = Some(value)) + + def generateDeclaration: String = + s"${if (isVar) "var " else ""}$name: ${pack.generateDeclaration}${default.map(d => s" = $d").getOrElse("")}" + + def constraints: Seq[String] = pack.constraints.map(_(name)) + def nChildren: NChildren = pack.nChildren(name) + def childrenSeq: ChildrenSeq = pack.childrenSeq(name) +} + +sealed abstract class AttOrChildPack { + def * : AttOrChildPack = Collection(this) + def + : AttOrChildPack = Collection(this, allowEmpty = false) + def ? : AttOrChildPack = Optional(this) + + def generateDeclaration: String + def constraints: Seq[String => String] = Seq.empty + def nChildren: String => NChildren = _ => NChildren() + def childrenSeq: String => ChildrenSeq +} + +final case class Att(typ: String, override val constraints: Seq[String => String] = Seq.empty) + extends AttOrChildPack { + def withConstraint(c: String => String): Att = copy(constraints = constraints :+ c) + + override def generateDeclaration: String = typ + override def childrenSeq: String => ChildrenSeq = _ => ChildrenSeq.empty +} + +final case class Child(t: String = "IR") extends AttOrChildPack { + override def generateDeclaration: String = t + override def nChildren: String => NChildren = _ => NChildren(static = 1) + override def childrenSeq: String => ChildrenSeq = x => ChildrenSeq.Static(Seq(x)) +} + +case object Binding extends AttOrChildPack { + override def generateDeclaration: String = "Binding" + override def nChildren: String => NChildren = _ => NChildren(static = 1) + override def childrenSeq: String => ChildrenSeq = x => ChildrenSeq.Static(Seq(s"$x.value")) +} + +final case class Optional(elt: AttOrChildPack) extends AttOrChildPack { + override def generateDeclaration: String = s"Option[${elt.generateDeclaration}]" + + override def nChildren: String => NChildren = elt match { + case Att(_, _) => _ => NChildren() + case _ => opt => NChildren(dynamic = s"$opt.map(x => ${elt.nChildren("x")}).sum") + } + + override def childrenSeq: String => ChildrenSeq = elt match { + case Att(_, _) => _ => ChildrenSeq.empty + case Child(_) => opt => ChildrenSeq.Dynamic(s"$opt.toSeq") + case _ => opt => + ChildrenSeq.Dynamic(s"$opt.toSeq.flatMap(x => ${elt.childrenSeq("x").asDyn.children})") + } +} + +final case class Collection(elt: AttOrChildPack, allowEmpty: Boolean = true) + extends AttOrChildPack { + override def generateDeclaration: String = s"IndexedSeq[${elt.generateDeclaration}]" + + override def constraints: Seq[String => String] = { + val nestedConstraints: Seq[String => String] = if (elt.constraints.nonEmpty) + Seq(elts => + s"$elts.forall(x => ${elt.constraints.map(c => s"(${c("x")})").mkString(" && ")})" + ) + else Seq() + val nonEmptyConstraint: Seq[String => String] = + if (allowEmpty) Seq() else Seq(x => s"$x.nonEmpty") + nestedConstraints ++ nonEmptyConstraint + } + + override def nChildren: String => NChildren = elt match { + case Att(_, _) => _ => NChildren() + case Child(_) => elts => NChildren(dynamic = s"$elts.size") + case _ => elts => NChildren(dynamic = s"$elts.map(x => ${elt.nChildren("x")}).sum") + } + + override def childrenSeq: String => ChildrenSeq = elt match { + case Att(_, _) => 
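+ // plain attributes are not IR children, so a collection of them contributes nothing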
_ => ChildrenSeq.empty + case Child(_) => elts => ChildrenSeq.Dynamic(elts) + case _ => + elts => ChildrenSeq.Dynamic(s"$elts.flatMap(x => ${elt.childrenSeq("x").asDyn.children})") + } +} + +final case class Tup(elts: AttOrChildPack*) extends AttOrChildPack { + override def generateDeclaration: String = + elts.map(_.generateDeclaration).mkString("(", ", ", ")") + + override def constraints: Seq[String => String] = + elts.zipWithIndex.flatMap { case (elt, i) => + elt.constraints.map(c => (t: String) => c(s"$t._${i + 1}")) + } + + override def nChildren: String => NChildren = t => + elts.zipWithIndex + .map { case (elt, i) => elt.nChildren(s"$t._${i + 1}") } + .foldLeft(NChildren())(_ + _) + + override def childrenSeq: String => ChildrenSeq = + t => + elts + .zipWithIndex.map { case (elt, i) => elt.childrenSeq(s"$t._${i + 1}") } + .foldLeft[ChildrenSeq](ChildrenSeq.empty)(_ ++ _) +} + +case class IR( + name: String, + attsAndChildren: Seq[NamedAttOrChildPack], + traits: Seq[Trait] = Seq.empty, + constraints: Seq[String] = Seq.empty, + extraMethods: Seq[String] = Seq.empty, + applyMethods: Seq[String] = Seq.empty, + docstring: String = "", +) { + def withTraits(newTraits: Trait*): IR = copy(traits = traits ++ newTraits) + def withMethod(methodDef: String): IR = copy(extraMethods = extraMethods :+ methodDef) + def withApply(methodDef: String): IR = copy(applyMethods = applyMethods :+ methodDef) + def typed(typ: String): IR = withTraits(Typed(typ)) + def withConstraint(c: String): IR = copy(constraints = constraints :+ c) + + def withCompanionExtension: IR = withApply( + s"implicit def toCompExt(comp: $name.type): ${name}Ext.type = ${name}Ext" + ) + + def withClassExtension: IR = withApply( + s"implicit def toExt(obj: $name): ${name}Ext = new ${name}Ext(obj)" + ) + + def withDocstring(docstring: String): IR = copy(docstring = docstring) + + private def nChildren: NChildren = attsAndChildren.foldLeft(NChildren())(_ + _.nChildren) + + private def children: String = + attsAndChildren.foldLeft[ChildrenSeq](ChildrenSeq.empty)(_ ++ _.childrenSeq).asDyn.children + + private def paramList = s"$name(${attsAndChildren.map(_.generateDeclaration).mkString(", ")})" + + private def classDecl = + s"final case class $paramList extends IR" + traits.map(" with " + _.name).mkString + + private def classBody = { + val extraMethods = + this.extraMethods :+ s"override lazy val childrenSeq: IndexedSeq[BaseIR] = $children" + val constraints = this.constraints ++ attsAndChildren.flatMap(_.constraints) + if (constraints.nonEmpty || extraMethods.nonEmpty) { + ( + " {" + + (if (constraints.nonEmpty) + constraints.map(c => s" require($c)").mkString("\n", "\n", "\n") + else "") + + ( + if (extraMethods.nonEmpty) + extraMethods.map(" " + _).mkString("\n", "\n", "\n") + else "" + ) + + "}" + ) + } else "" + } + + private def classDef = + (if (docstring.nonEmpty) s"\n/** $docstring*/\n" else "") + classDecl + classBody + + private def companionBody = applyMethods.map(" " + _).mkString("\n") + + private def companionDef = + if (companionBody.isEmpty) "" else s"object $name {\n$companionBody\n}\n" + + def generateDef: String = companionDef + classDef + "\n" +} + +object Main { + private def node(name: String, attsAndChildren: NamedAttOrChildPack*): IR = + IR(name, attsAndChildren) + + implicit private def makeNamedPack(tup: (String, AttOrChildPack)): NamedAttOrChildPack = + NamedAttOrChildPack(tup._1, tup._2) + + private def att(typ: String): Att = Att(typ) + + private val binding = Binding + + private def tup(elts: 
AttOrChildPack*): Tup = Tup(elts: _*) + + private def child = Child() + + abstract private class BaseImplicit(t: AttOrChildPack, defaultName: String) { + implicit def makeNamedChild(tup: (String, this.type)): NamedAttOrChildPack = + NamedAttOrChildPack(tup._1, t) + + implicit def makeDefaultNamedChild(x: this.type): NamedAttOrChildPack = + makeNamedChild((defaultName, x)) + + implicit def makeChild(x: this.type): AttOrChildPack = t + } + + private object name extends BaseImplicit(Att("Name"), "name") + private object key extends BaseImplicit(Att("IndexedSeq[String]"), "key") + private object tableChild extends BaseImplicit(Child("TableIR"), "child") + private object matrixChild extends BaseImplicit(Child("MatrixIR"), "child") + private object blockMatrixChild extends BaseImplicit(Child("BlockMatrixIR"), "child") + + private val errorID = ("errorID", att("Int")).withDefault("ErrorIDs.NO_ERROR") + + private def _typ(t: String = "Type") = ("_typ", att(t)) + + private val mmPerElt = ("requiresMemoryManagementPerElement", att("Boolean")).withDefault("false") + + private def allNodes: Seq[IR] = { + // scalafmt: {} + + val r = Seq.newBuilder[IR] + + r += node("I32", ("x", att("Int"))).withTraits(Trivial) + r += node("I64", ("x", att("Long"))).withTraits(Trivial) + r += node("F32", ("x", att("Float"))).withTraits(Trivial) + r += node("F64", ("x", att("Double"))).withTraits(Trivial) + r += node("Str", ("x", att("String"))).withTraits(Trivial) + .withMethod( + "override def toString(): String = s\"\"\"Str(\"${StringEscapeUtils.escapeString(x)}\")\"\"\"" + ) + r += node("True").withTraits(Trivial) + r += node("False").withTraits(Trivial) + r += node("Void").withTraits(Trivial) + r += node("NA", _typ()).withTraits(Trivial) + r += node("UUID4", ("id", att("String"))) + .withDocstring( + """WARNING! This node can only be used when trying to append a one-off, + |random string that will not be reused elsewhere in the pipeline. + |Any other uses will need to write and then read again; this node is non-deterministic + |and will not e.g. exhibit the correct semantics when self-joining on streams. 
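+ |(The companion's `apply()`, supplied by `UUID4Ext` below, fills `id` with `genUID()`, so each constructed node carries a fresh id.)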
+ |""".stripMargin + ) + .withCompanionExtension + + r += node( + "Literal", + ("_typ", att("Type").withConstraint(self => s"!CanEmit($self)")), + ("value", att("Annotation").withConstraint(self => s"$self != null")), + ) + .withCompanionExtension + .withMethod( + """// expensive, for debugging + |// require(SafeRow.isSafe(value)) + |// assert(_typ.typeCheck(value), s"literal invalid:\n ${_typ}\n $value") + |""".stripMargin + ) + + r += node( + "EncodedLiteral", + ( + "codec", + att("AbstractTypedCodecSpec").withConstraint(self => + s"!CanEmit($self.encodedVirtualType)" + ), + ), + ("value", att("WrappedByteArrays").withConstraint(self => s"$self != null")), + ) + .withCompanionExtension + + r += node("Cast", ("v", child), _typ()) + r += node("CastRename", ("v", child), _typ()) + + r += node("IsNA", ("value", child)) + r += node("Coalesce", ("values", child.+)) + r += node("Consume", ("value", child)) + + r += node("If", ("cond", child), ("cnsq", child), ("altr", child)) + r += node("Switch", ("x", child), ("default", child), ("cases", child.*)) + .withMethod("override lazy val size: Int = 2 + cases.length") + + r += node("Block", ("bindings", binding.*), ("body", child)) + .withMethod("override lazy val size: Int = bindings.length + 1") + .withCompanionExtension + + r += node("Ref", name, _typ().mutable).withTraits(BaseRef) + + r += node( + "TailLoop", + name, + ("params", tup(name, child).*), + ("resultType", att("Type")), + ("body", child), + ) + .withDocstring( + """Recur can't exist outside of loop. Loops can be nested, but we can't call outer + |loops in terms of inner loops so there can only be one loop "active" in a given + |context. + |""".stripMargin + ) + .withMethod("lazy val paramIdx: Map[Name, Int] = params.map(_._1).zipWithIndex.toMap") + r += node("Recur", name, ("args", child.*), _typ().mutable).withTraits(BaseRef) + + r += node("RelationalLet", name, ("value", child), ("body", child)) + r += node("RelationalRef", name, _typ()).withTraits(BaseRef) + + r += node("ApplyBinaryPrimOp", ("op", att("BinaryOp")), ("l", child), ("r", child)) + r += node("ApplyUnaryPrimOp", ("op", att("UnaryOp")), ("x", child)) + r += node( + "ApplyComparisonOp", + ("op", att("ComparisonOp[_]")).mutable, + ("l", child), + ("r", child), + ) + + r += node("MakeArray", ("args", child.*), _typ("TArray")).withCompanionExtension + r += node("MakeStream", ("args", child.*), _typ("TStream"), mmPerElt).withCompanionExtension + r += node("ArrayRef", ("a", child), ("i", child), errorID) + r += node( + "ArraySlice", + ("a", child), + ("start", child), + ("stop", child.?), + ("step", child).withDefault("I32(1)"), + errorID, + ) + r += node("ArrayLen", ("a", child)) + r += node("ArrayZeros", ("length", child)) + r += node( + "ArrayMaximalIndependentSet", + ("edges", child), + ("tieBreaker", tup(name, name, child).?), + ) + + r += node("StreamIota", ("start", child), ("step", child), mmPerElt) + .withDocstring( + """[[StreamIota]] is an infinite stream producer, whose element is an integer starting at + |`start`, updated by `step` at each iteration. 
+ |The name comes from APL:
+ |[[https://stackoverflow.com/questions/9244879/what-does-iota-of-stdiota-stand-for]]
+ |""".stripMargin
+ )
+ r += node("StreamRange", ("start", child), ("stop", child), ("step", child), mmPerElt, errorID)
+
+ r += node("ArraySort", ("a", child), ("left", name), ("right", name), ("lessThan", child))
+ .withCompanionExtension
+
+ r += node("ToSet", ("a", child))
+ r += node("ToDict", ("a", child))
+ r += node("ToArray", ("a", child))
+ r += node("CastToArray", ("a", child))
+ r += node("ToStream", ("a", child), mmPerElt)
+ r += node("GroupByKey", ("collection", child))
+
+ r += node(
+ "StreamBufferedAggregate",
+ ("streamChild", child),
+ ("initAggs", child),
+ ("newKey", child),
+ ("seqOps", child),
+ name,
+ ("aggSignature", att("IndexedSeq[PhysicalAggSig]")),
+ ("bufferSize", att("Int")),
+ )
+ r += node(
+ "LowerBoundOnOrderedCollection",
+ ("orderedCollection", child),
+ ("elem", child),
+ ("onKey", att("Boolean")),
+ )
+
+ r += node("RNGStateLiteral")
+ r += node("RNGSplit", ("state", child), ("dynBitstring", child))
+
+ r += node("StreamLen", ("a", child))
+ r += node("StreamGrouped", ("a", child), ("groupSize", child))
+ r += node("StreamGroupByKey", ("a", child), key, ("missingEqual", att("Boolean")))
+ r += node("StreamMap", ("a", child), name, ("body", child)).typed("TStream")
+ r += node("StreamTakeWhile", ("a", child), ("elementName", name), ("body", child))
+ .typed("TStream")
+ r += node("StreamDropWhile", ("a", child), ("elementName", name), ("body", child))
+ .typed("TStream")
+ r += node("StreamTake", ("a", child), ("num", child)).typed("TStream")
+ r += node("StreamDrop", ("a", child), ("num", child)).typed("TStream")
+
+ r += node(
+ "SeqSample",
+ ("totalRange", child),
+ ("numToSample", child),
+ ("rngState", child),
+ mmPerElt,
+ )
+ .typed("TStream")
+ .withDocstring(
+ """Generate, in ascending order, a uniform random sample, without replacement, of
+ |numToSample integers in the range [0, totalRange).
+ |""".stripMargin
+ )
+
+ r += node(
+ "StreamDistribute",
+ ("child", child),
+ ("pivots", child),
+ ("path", child),
+ ("comparisonOp", att("ComparisonOp[_]")),
+ ("spec", att("AbstractTypedCodecSpec")),
+ )
+ .withDocstring(
+ """Take the child stream and sort each element into buckets based on the provided pivots.
+ |The first and last elements of pivots are the endpoints of the first and last interval,
+ |respectively, and should not be contained in the dataset.
+ |""".stripMargin
+ )
+
+ r += node(
+ "StreamWhiten",
+ ("stream", child),
+ ("newChunk", att("String")),
+ ("prevWindow", att("String")),
+ ("vecSize", att("Int")),
+ ("windowSize", att("Int")),
+ ("chunkSize", att("Int")),
+ ("blockSize", att("Int")),
+ ("normalizeAfterWhiten", att("Boolean")),
+ )
+ .typed("TStream")
+ .withDocstring(
+ """"Whiten" a stream of vectors by regressing out from each vector all components
+ |in the direction of vectors in the preceding window. For efficiency, takes
+ |a stream of "chunks" of vectors.
+ |Takes a stream of structs, with two designated fields: `prevWindow` is the
+ |previous window (e.g. from the previous partition), if there is one, and
+ |`newChunk` is the new chunk to whiten. 
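+ |(By the attribute names below, chunks hold `vecSize`-length vectors, at most `chunkSize` of them, whitened against a `windowSize`-wide history `blockSize` columns at a time; an inference from the names, not a documented contract.)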
+ |""".stripMargin + ) + + r += node( + "StreamZip", + ("as", child.*), + ("names", name.*), + ("body", child), + ("behavior", att("ArrayZipBehavior.ArrayZipBehavior")), + errorID, + ) + .typed("TStream") + + r += node("StreamMultiMerge", ("as", child.*), key).typed("TStream") + + r += node( + "StreamZipJoinProducers", + ("contexts", child), + ("ctxName", name), + ("makeProducer", child), + key, + ("curKey", name), + ("curVals", name), + ("joinF", child), + ) + .typed("TStream") + + r += node( + "StreamZipJoin", + ("as", child.*), + key, + ("curKey", name), + ("curVals", name), + ("joinF", child), + ) + .typed("TStream") + .withDocstring( + """The StreamZipJoin node assumes that input streams have distinct keys. If input streams do not + |have distinct keys, the key that is included in the result is undefined, but is likely the + |last. + |""".stripMargin + ) + + r += node("StreamFilter", ("a", child), name, ("cond", child)).typed("TStream") + r += node("StreamFlatMap", ("a", child), name, ("cond", child)).typed("TStream") + + r += node( + "StreamFold", + ("a", child), + ("zero", child), + ("accumName", name), + ("valueName", name), + ("body", child), + ) + + r += node( + "StreamFold2", + ("a", child), + ("accum", tup(name, child).*), + ("valueName", name), + ("seq", child.*), + ("result", child), + ) + .withConstraint("accum.length == seq.length") + .withMethod("val nameIdx: Map[Name, Int] = accum.map(_._1).zipWithIndex.toMap") + .withCompanionExtension + + r += node( + "StreamScan", + ("a", child), + ("zero", child), + ("accumName", name), + ("valueName", name), + ("body", child), + ) + .typed("TStream") + + r += node("StreamFor", ("a", child), ("valueName", name), ("body", child)).typed("TVoid.type") + r += node("StreamAgg", ("a", child), name, ("query", child)) + r += node("StreamAggScan", ("a", child), name, ("query", child)).typed("TStream") + + r += node( + "StreamLeftIntervalJoin", + ("left", child), + ("right", child), + ("lKeyFieldName", att("String")), + ("rIntervalFieldName", att("String")), + ("lname", name), + ("rname", name), + ("body", child), + ) + .typed("TStream") + + r += node( + "StreamJoinRightDistinct", + ("left", child), + ("right", child), + ("lKey", att("IndexedSeq[String]")), + ("rKey", att("IndexedSeq[String]")), + ("l", name), + ("r", name), + ("joinF", child), + ("joinType", att("String")), + ) + .typed("TStream").withClassExtension + + r += node( + "StreamLocalLDPrune", + ("child", child), + ("r2Threshold", child), + ("windowSize", child), + ("maxQueueSize", child), + ("nSamples", child), + ) + .typed("TStream") + + r += node("MakeNDArray", ("data", child), ("shape", child), ("rowMajor", child), errorID) + .withTraits(NDArray).withCompanionExtension + r += node("NDArrayShape", ("nd", child)) + r += node("NDArrayReshape", ("nd", child), ("shape", child), errorID).withTraits(NDArray) + r += node("NDArrayConcat", ("nds", child), ("axis", att("Int"))).withTraits(NDArray) + r += node("NDArrayRef", ("nd", child), ("idxs", child.*), errorID) + r += node("NDArraySlice", ("nd", child), ("slices", child)).withTraits(NDArray) + r += node("NDArrayFilter", ("nd", child), ("keep", child.*)).withTraits(NDArray) + r += node("NDArrayMap", ("nd", child), ("valueName", name), ("body", child)).withTraits(NDArray) + r += node( + "NDArrayMap2", + ("l", child), + ("r", child), + ("lName", name), + ("rName", name), + ("body", child), + errorID, + ) + .withTraits(NDArray) + r += node("NDArrayReindex", ("nd", child), ("indexExpr", att("IndexedSeq[Int]"))) + .withTraits(NDArray) + r += 
node("NDArrayAgg", ("nd", child), ("axes", att("IndexedSeq[Int]"))) + r += node("NDArrayWrite", ("nd", child), ("path", child)).typed("TVoid.type") + r += node("NDArrayMatMul", ("l", child), ("r", child), errorID).withTraits(NDArray) + r += node("NDArrayQR", ("nd", child), ("mode", att("String")), errorID).withCompanionExtension + r += node( + "NDArraySVD", + ("nd", child), + ("fullMatrices", att("Boolean")), + ("computeUV", att("Boolean")), + errorID, + ) + .withCompanionExtension + r += node("NDArrayEigh", ("nd", child), ("eigvalsOnly", att("Boolean")), errorID) + .withCompanionExtension + r += node("NDArrayInv", ("nd", child), errorID).withTraits(NDArray).withCompanionExtension + + val isScan = ("isScan", att("Boolean")) + + r += node("AggFilter", ("cond", child), ("aggIR", child), isScan) + r += node("AggExplode", ("array", child), name, ("aggBody", child), isScan) + r += node("AggGroupBy", ("key", child), ("aggIR", child), isScan) + r += node( + "AggArrayPerElement", + ("a", child), + ("elementName", name), + ("indexName", name), + ("aggBody", child), + ("knownLength", child.?), + isScan, + ) + r += node( + "AggFold", + ("zero", child), + ("seqOp", child), + ("combOp", child), + ("accumName", name), + ("otherAccumName", name), + isScan, + ) + .withCompanionExtension + + r += node( + "ApplyAggOp", + ("initOpArgs", child.*), + ("seqOpArgs", child.*), + ("aggSig", att("AggSignature")), + ) + .withClassExtension.withCompanionExtension + r += node( + "ApplyScanOp", + ("initOpArgs", child.*), + ("seqOpArgs", child.*), + ("aggSig", att("AggSignature")), + ) + .withClassExtension.withCompanionExtension + r += node("InitOp", ("i", att("Int")), ("args", child.*), ("aggSig", att("PhysicalAggSig"))) + r += node("SeqOp", ("i", att("Int")), ("args", child.*), ("aggSig", att("PhysicalAggSig"))) + r += node("CombOp", ("i1", att("Int")), ("i2", att("Int")), ("aggSig", att("PhysicalAggSig"))) + r += node("ResultOp", ("idx", att("Int")), ("aggSig", att("PhysicalAggSig"))) + .withCompanionExtension + r += node("CombOpValue", ("i", att("Int")), ("value", child), ("aggSig", att("PhysicalAggSig"))) + r += node("AggStateValue", ("i", att("Int")), ("aggSig", att("AggStateSig"))) + r += node( + "InitFromSerializedValue", + ("i", att("Int")), + ("value", child), + ("aggSig", att("AggStateSig")), + ) + r += node( + "SerializeAggs", + ("startIdx", att("Int")), + ("serializedIdx", att("Int")), + ("spec", att("BufferSpec")), + ("aggSigs", att("IndexedSeq[AggStateSig]")), + ) + r += node( + "DeserializeAggs", + ("startIdx", att("Int")), + ("serializedIdx", att("Int")), + ("spec", att("BufferSpec")), + ("aggSigs", att("IndexedSeq[AggStateSig]")), + ) + r += node( + "RunAgg", + ("body", child), + ("result", child), + ("signature", att("IndexedSeq[AggStateSig]")), + ) + r += node( + "RunAggScan", + ("array", child), + name, + ("init", child), + ("seqs", child), + ("result", child), + ("signature", att("IndexedSeq[AggStateSig]")), + ) + + r += node("Begin", ("xs", child.*)).withCompanionExtension + + r += node("MakeStruct", ("fields", tup(att("String"), child).*)).typed("TStruct") + r += node("SelectFields", ("old", child), ("fields", att("IndexedSeq[String]"))) + .typed("TStruct") + r += node( + "InsertFields", + ("old", child), + ("fields", tup(att("String"), child).*), + ("fieldOrder", att("Option[IndexedSeq[String]]")).withDefault("None"), + ) + .typed("TStruct") + r += node("GetField", ("o", child), ("name", att("String"))) + r += node("MakeTuple", ("fields", tup(att("Int"), child).*)) + 
.typed("TTuple").withCompanionExtension + r += node("GetTupleElement", ("o", child), ("idx", att("Int"))) + + r += node("In", ("i", att("Int")), ("_typ", att("EmitParamType"))) + .withDocstring("Function input").withCompanionExtension + + r += node("Die", ("message", child), ("_typ", att("Type")), errorID).withCompanionExtension + r += node("Trap", ("child", child)).withDocstring( + """The Trap node runs the `child` node with an exception handler. If the child throws a + |HailException (user exception), then we return the tuple ((msg, errorId), NA). If the child + |throws any other exception, we raise that exception. If the child does not throw, then we + |return the tuple (NA, child value). + |""".stripMargin + ) + r += node("ConsoleLog", ("message", child), ("result", child)) + + r += node( + "ApplyIR", + ("function", att("String")), + ("typeArgs", att("Seq[Type]")), + ("args", child.*), + ("returnType", att("Type")), + errorID, + ) + .withMethod( + """var conversion: (Seq[Type], IndexedSeq[IR], Int) => IR = _ + |var inline: Boolean = _ + | + |private lazy val refs = args.map(a => Ref(freshName(), a.typ)).toArray + |lazy val body: IR = conversion(typeArgs, refs, errorID).deepCopy() + |lazy val refIdx: Map[Name, Int] = refs.map(_.name).zipWithIndex.toMap + | + |lazy val explicitNode: IR = { + | val ir = Let(refs.map(_.name).zip(args), body) + | assert(ir.typ == returnType) + | ir + |} + |""".stripMargin + ) + + r += node( + "Apply", + ("function", att("String")), + ("typeArgs", att("Seq[Type]")), + ("args", child.*), + ("returnType", att("Type")), + errorID, + ).withTraits(Apply()) + + r += node( + "ApplySeeded", + ("function", att("String")), + ("_args", child.*), + ("rngState", child), + ("staticUID", att("Long")), + ("returnType", att("Type")), + ).withTraits(Apply()) + .withMethod("val args = rngState +: _args") + .withMethod("val typeArgs: Seq[Type] = Seq.empty[Type]") + + r += node( + "ApplySpecial", + ("function", att("String")), + ("typeArgs", att("Seq[Type]")), + ("args", child.*), + ("returnType", att("Type")), + errorID, + ).withTraits(Apply(missingnessAware = true)) + + r += node("LiftMeOut", ("child", child)) + + r += node("TableCount", tableChild) + r += node("MatrixCount", matrixChild) + r += node("TableAggregate", tableChild, ("query", child)) + r += node("MatrixAggregate", matrixChild, ("query", child)) + r += node("TableWrite", tableChild, ("writer", att("TableWriter"))) + r += node( + "TableMultiWrite", + ("_children", tableChild.*), + ("writer", att("WrappedMatrixNativeMultiWriter")), + ) + r += node("TableGetGlobals", tableChild) + r += node("TableCollect", tableChild) + r += node("MatrixWrite", matrixChild, ("writer", att("MatrixWriter"))) + r += node( + "MatrixMultiWrite", + ("_children", matrixChild.*), + ("writer", att("MatrixNativeMultiWriter")), + ) + r += node("TableToValueApply", tableChild, ("function", att("TableToValueFunction"))) + r += node("MatrixToValueApply", matrixChild, ("function", att("MatrixToValueFunction"))) + r += node( + "BlockMatrixToValueApply", + blockMatrixChild, + ("function", att("BlockMatrixToValueFunction")), + ) + r += node("BlockMatrixCollect", blockMatrixChild) + r += node("BlockMatrixWrite", blockMatrixChild, ("writer", att("BlockMatrixWriter"))) + r += node( + "BlockMatrixMultiWrite", + ("blockMatrices", blockMatrixChild.*), + ("writer", att("BlockMatrixMultiWriter")), + ) + + r += node( + "CollectDistributedArray", + ("contexts", child), + ("globals", child), + ("cname", name), + ("gname", name), + ("body", child), + 
("dynamicID", child), + ("staticID", att("String")), + ("tsd", att("Option[TableStageDependency]")).withDefault("None"), + ) + + r += node( + "ReadPartition", + ("context", child), + ("rowType", att("TStruct")), + ("reader", att("PartitionReader")), + ) + r += node( + "WritePartition", + ("value", child), + ("writeCtx", child), + ("writer", att("PartitionWriter")), + ) + r += node("WriteMetadata", ("writeAnnotations", child), ("writer", att("MetadataWriter"))) + r += node( + "ReadValue", + ("path", child), + ("reader", att("ValueReader")), + ("requestedType", att("Type")), + ) + r += node( + "WriteValue", + ("value", child), + ("path", child), + ("writer", att("ValueWriter")), + ("stagingFile", att("Option[IR]")).withDefault("None"), + ) + + r.result() + } + + @main + def main(path: String) = { + val pack = "package is.hail.expr.ir.defs" + val imports = Seq( + "is.hail.annotations.Annotation", + "is.hail.io.{AbstractTypedCodecSpec, BufferSpec}", + "is.hail.types.virtual.{Type, TArray, TStream, TVoid, TStruct, TTuple}", + "is.hail.utils.{FastSeq, StringEscapeUtils}", + "is.hail.expr.ir.{BaseIR, IR, TableIR, MatrixIR, BlockMatrixIR, Name, UnaryOp, BinaryOp, " + + "ComparisonOp, CanEmit, AggSignature, EmitParamType, freshName, TableWriter, " + + "WrappedMatrixNativeMultiWriter, MatrixWriter, MatrixNativeMultiWriter, BlockMatrixWriter, " + + "BlockMatrixMultiWriter, ValueReader, ValueWriter}", + "is.hail.expr.ir.lowering.TableStageDependency", + "is.hail.expr.ir.agg.{PhysicalAggSig, AggStateSig}", + "is.hail.expr.ir.functions.{UnseededMissingnessAwareJVMFunction, " + + "UnseededMissingnessObliviousJVMFunction, TableToValueFunction, MatrixToValueFunction, " + + "BlockMatrixToValueFunction}", + "is.hail.expr.ir.defs.exts._", + ) + val gen = pack + "\n\n" + imports.map(i => s"import $i").mkString("\n") + "\n\n" + allNodes.map( + _.generateDef + ).mkString("\n") + os.write(os.Path(path) / "IR_gen.scala", gen) + } + + def main(args: Array[String]): Unit = ParserForMethods(this).runOrExit(args) +} diff --git a/hail/hail/src/is/hail/expr/ir/IR.scala b/hail/hail/src/is/hail/expr/ir/IR.scala index 32dcf26f3f6..b649c93fad8 100644 --- a/hail/hail/src/is/hail/expr/ir/IR.scala +++ b/hail/hail/src/is/hail/expr/ir/IR.scala @@ -1,12 +1,11 @@ package is.hail.expr.ir -import is.hail.annotations.{Annotation, Region} +import is.hail.annotations.Region import is.hail.asm4s.Value import is.hail.backend.ExecuteContext -import is.hail.expr.ir.agg.{AggStateSig, PhysicalAggSig} +import is.hail.expr.ir.agg.PhysicalAggSig import is.hail.expr.ir.defs.ApplyIR import is.hail.expr.ir.functions._ -import is.hail.expr.ir.lowering.TableStageDependency import is.hail.expr.ir.streams.StreamProducer import is.hail.io.{AbstractTypedCodecSpec, BufferSpec, TypedCodecSpec} import is.hail.io.avro.{AvroPartitionReader, AvroSchemaSerializer} @@ -27,20 +26,22 @@ import java.io.OutputStream import org.json4s.{DefaultFormats, Extraction, Formats, JValue, ShortTypeHints} import org.json4s.JsonAST.{JNothing, JString} -sealed trait IR extends BaseIR { +trait IR extends BaseIR { private var _typ: Type = null - def typ: Type = { - if (_typ == null) + override def typ: Type = { + if (_typ == null) { try _typ = InferType(this) catch { case e: Throwable => throw new RuntimeException(s"typ: inference failure:", e) } + assert(_typ != null) + } _typ } - protected lazy val childrenSeq: IndexedSeq[BaseIR] = + override protected lazy val childrenSeq: IndexedSeq[BaseIR] = Children(this) override protected def copyWithNewChildren(newChildren: 
IndexedSeq[BaseIR]): IR = @@ -74,65 +75,12 @@ sealed trait IR extends BaseIR { package defs { - import is.hail.expr.ir.defs.ArrayZipBehavior.ArrayZipBehavior - - sealed trait TypedIR[T <: Type] extends IR { + trait TypedIR[T <: Type] extends IR { override def typ: T = tcoerce[T](super.typ) } // Mark Refs and constants as IRs that are safe to duplicate - sealed trait TrivialIR extends IR - - object Literal { - def coerce(t: Type, x: Any): IR = { - if (x == null) - return NA(t) - t match { - case TInt32 => I32(x.asInstanceOf[Number].intValue()) - case TInt64 => I64(x.asInstanceOf[Number].longValue()) - case TFloat32 => F32(x.asInstanceOf[Number].floatValue()) - case TFloat64 => F64(x.asInstanceOf[Number].doubleValue()) - case TBoolean => if (x.asInstanceOf[Boolean]) True() else False() - case TString => Str(x.asInstanceOf[String]) - case _ => Literal(t, x) - } - } - } - - final case class Literal(_typ: Type, value: Annotation) extends IR { - require(!CanEmit(_typ)) - require(value != null) - // expensive, for debugging - // require(SafeRow.isSafe(value)) - // assert(_typ.typeCheck(value), s"literal invalid:\n ${_typ}\n $value") - } - - object EncodedLiteral { - def apply(codec: AbstractTypedCodecSpec, value: Array[Array[Byte]]): EncodedLiteral = - EncodedLiteral(codec, new WrappedByteArrays(value)) - - def fromPTypeAndAddress(pt: PType, addr: Long, ctx: ExecuteContext): IR = { - pt match { - case _: PInt32 => I32(Region.loadInt(addr)) - case _: PInt64 => I64(Region.loadLong(addr)) - case _: PFloat32 => F32(Region.loadFloat(addr)) - case _: PFloat64 => F64(Region.loadDouble(addr)) - case _: PBoolean => if (Region.loadBoolean(addr)) True() else False() - case ts: PString => Str(ts.loadString(addr)) - case _ => - val etype = EType.defaultFromPType(ctx, pt) - val codec = TypedCodecSpec(etype, pt.virtualType, BufferSpec.wireSpec) - val bytes = codec.encodeArrays(ctx, pt, addr) - EncodedLiteral(codec, bytes) - } - } - } - - final case class EncodedLiteral(codec: AbstractTypedCodecSpec, value: WrappedByteArrays) - extends IR { - require(!CanEmit(codec.encodedVirtualType)) - require(value != null) - } + trait TrivialIR extends IR class WrappedByteArrays(val ba: Array[Array[Byte]]) { override def hashCode(): Int = @@ -150,49 +98,6 @@ package defs { } } - final case class I32(x: Int) extends IR with TrivialIR - final case class I64(x: Long) extends IR with TrivialIR - final case class F32(x: Float) extends IR with TrivialIR - final case class F64(x: Double) extends IR with TrivialIR - - final case class Str(x: String) extends IR with TrivialIR { - override def toString(): String = s"""Str("${StringEscapeUtils.escapeString(x)}")""" - } - - final case class True() extends IR with TrivialIR - final case class False() extends IR with TrivialIR - final case class Void() extends IR with TrivialIR - - object UUID4 { - def apply(): UUID4 = UUID4(genUID()) - } - -// WARNING! This node can only be used when trying to append a one-off, -// random string that will not be reused elsewhere in the pipeline. -// Any other uses will need to write and then read again; this node is -// non-deterministic and will not e.g. exhibit the correct semantics when -// self-joining on streams. 
- final case class UUID4(id: String) extends IR - - final case class Cast(v: IR, _typ: Type) extends IR - final case class CastRename(v: IR, _typ: Type) extends IR - - final case class NA(_typ: Type) extends IR with TrivialIR - final case class IsNA(value: IR) extends IR - - final case class Coalesce(values: Seq[IR]) extends IR { - require(values.nonEmpty) - } - - final case class Consume(value: IR) extends IR - - final case class If(cond: IR, cnsq: IR, altr: IR) extends IR - - final case class Switch(x: IR, default: IR, cases: IndexedSeq[IR]) extends IR { - override lazy val size: Int = - 2 + cases.length - } - object AggLet { def apply(name: Name, value: IR, body: IR, isScan: Boolean): IR = { val scope = if (isScan) Scope.SCAN else Scope.AGG @@ -215,284 +120,15 @@ package defs { Let(bindings.init, bindings.last._2) } } - } case class Binding(name: Name, value: IR, scope: Int = Scope.EVAL) - final case class Block(bindings: IndexedSeq[Binding], body: IR) extends IR { - override lazy val size: Int = - bindings.length + 1 - } - - object Block { - object Insert { - def unapply(bindings: IndexedSeq[Binding]) - : Option[(IndexedSeq[Binding], Binding, IndexedSeq[Binding])] = { - val idx = bindings.indexWhere(_.value.isInstanceOf[InsertFields]) - if (idx == -1) None else Some((bindings.take(idx), bindings(idx), bindings.drop(idx + 1))) - } - } - - object Nested { - def unapply(bindings: IndexedSeq[Binding]): Option[(Int, IndexedSeq[Binding])] = { - val idx = bindings.indexWhere(_.value.isInstanceOf[Block]) - if (idx == -1) None else Some((idx, bindings)) - } - } - } - - sealed abstract class BaseRef extends IR with TrivialIR { + trait BaseRef extends IR with TrivialIR { def name: Name def _typ: Type } - final case class Ref(name: Name, var _typ: Type) extends BaseRef { - override def typ: Type = { - assert(_typ != null) - _typ - } - } - -// Recur can't exist outside of loop -// Loops can be nested, but we can't call outer loops in terms of inner loops so there can only be one loop "active" in a given context - final case class TailLoop( - name: Name, - params: IndexedSeq[(Name, IR)], - resultType: Type, - body: IR, - ) extends IR { - lazy val paramIdx: Map[Name, Int] = params.map(_._1).zipWithIndex.toMap - } - - final case class Recur(name: Name, args: IndexedSeq[IR], var _typ: Type) extends BaseRef - - final case class RelationalLet(name: Name, value: IR, body: IR) extends IR - final case class RelationalRef(name: Name, _typ: Type) extends BaseRef - - final case class ApplyBinaryPrimOp(op: BinaryOp, l: IR, r: IR) extends IR - final case class ApplyUnaryPrimOp(op: UnaryOp, x: IR) extends IR - final case class ApplyComparisonOp(var op: ComparisonOp[_], l: IR, r: IR) extends IR - - object MakeArray { - def apply(args: IR*): MakeArray = { - assert(args.nonEmpty) - MakeArray(args.toArray, TArray(args.head.typ)) - } - - def unify(ctx: ExecuteContext, args: IndexedSeq[IR], requestedType: TArray = null) - : MakeArray = { - assert(requestedType != null || args.nonEmpty) - - if (args.nonEmpty) - if (args.forall(_.typ == args.head.typ)) - return MakeArray(args, TArray(args.head.typ)) - - MakeArray( - args.map { arg => - val upcast = PruneDeadFields.upcast(ctx, arg, requestedType.elementType) - assert(upcast.typ == requestedType.elementType) - upcast - }, - requestedType, - ) - } - } - - final case class MakeArray(args: IndexedSeq[IR], _typ: TArray) extends IR - - object MakeStream { - def unify( - ctx: ExecuteContext, - args: IndexedSeq[IR], - requiresMemoryManagementPerElement: Boolean = false, - 
requestedType: TStream = null, - ): MakeStream = { - assert(requestedType != null || args.nonEmpty) - - if (args.nonEmpty) - if (args.forall(_.typ == args.head.typ)) - return MakeStream(args, TStream(args.head.typ), requiresMemoryManagementPerElement) - - MakeStream( - args.map { arg => - val upcast = PruneDeadFields.upcast(ctx, arg, requestedType.elementType) - assert(upcast.typ == requestedType.elementType) - upcast - }, - requestedType, - requiresMemoryManagementPerElement, - ) - } - } - - final case class MakeStream( - args: IndexedSeq[IR], - _typ: TStream, - requiresMemoryManagementPerElement: Boolean = false, - ) extends IR - - object ArrayRef { - def apply(a: IR, i: IR): ArrayRef = ArrayRef(a, i, ErrorIDs.NO_ERROR) - } - - final case class ArrayRef(a: IR, i: IR, errorID: Int) extends IR - - final case class ArraySlice( - a: IR, - start: IR, - stop: Option[IR], - step: IR = I32(1), - errorID: Int = ErrorIDs.NO_ERROR, - ) extends IR - - final case class ArrayLen(a: IR) extends IR - - final case class ArrayZeros(length: IR) extends IR - - final case class ArrayMaximalIndependentSet(edges: IR, tieBreaker: Option[(Name, Name, IR)]) - extends IR - - /** [[StreamIota]] is an infinite stream producer, whose element is an integer starting at - * `start`, updated by `step` at each iteration. The name comes from APL: - * [[https://stackoverflow.com/questions/9244879/what-does-iota-of-stdiota-stand-for]] - */ - final case class StreamIota( - start: IR, - step: IR, - requiresMemoryManagementPerElement: Boolean = false, - ) extends IR - - final case class StreamRange( - start: IR, - stop: IR, - step: IR, - requiresMemoryManagementPerElement: Boolean = false, - errorID: Int = ErrorIDs.NO_ERROR, - ) extends IR - - object ArraySort { - def apply(a: IR, ascending: IR = True(), onKey: Boolean = false): ArraySort = { - val l = freshName() - val r = freshName() - val atyp = tcoerce[TStream](a.typ) - val compare = if (onKey) { - val elementType = atyp.elementType.asInstanceOf[TBaseStruct] - elementType match { - case _: TStruct => - val elt = tcoerce[TStruct](atyp.elementType) - ApplyComparisonOp( - Compare(elt.types(0)), - GetField(Ref(l, elt), elt.fieldNames(0)), - GetField(Ref(r, atyp.elementType), elt.fieldNames(0)), - ) - case _: TTuple => - val elt = tcoerce[TTuple](atyp.elementType) - ApplyComparisonOp( - Compare(elt.types(0)), - GetTupleElement(Ref(l, elt), elt.fields(0).index), - GetTupleElement(Ref(r, atyp.elementType), elt.fields(0).index), - ) - } - } else { - ApplyComparisonOp( - Compare(atyp.elementType), - Ref(l, atyp.elementType), - Ref(r, atyp.elementType), - ) - } - - ArraySort(a, l, r, If(ascending, compare < 0, compare > 0)) - } - } - - final case class ArraySort(a: IR, left: Name, right: Name, lessThan: IR) extends IR - - final case class ToSet(a: IR) extends IR - - final case class ToDict(a: IR) extends IR - - final case class ToArray(a: IR) extends IR - - final case class CastToArray(a: IR) extends IR - - final case class ToStream(a: IR, requiresMemoryManagementPerElement: Boolean = false) extends IR - - final case class StreamBufferedAggregate( - streamChild: IR, - initAggs: IR, - newKey: IR, - seqOps: IR, - name: Name, - aggSignatures: IndexedSeq[PhysicalAggSig], - bufferSize: Int, - ) extends IR - - final case class LowerBoundOnOrderedCollection(orderedCollection: IR, elem: IR, onKey: Boolean) - extends IR - - final case class GroupByKey(collection: IR) extends IR - - final case class RNGStateLiteral() extends IR - - final case class RNGSplit(state: IR, dynBitstring: IR) extends 
IR - - final case class StreamLen(a: IR) extends IR - - final case class StreamGrouped(a: IR, groupSize: IR) extends IR - - final case class StreamGroupByKey(a: IR, key: IndexedSeq[String], missingEqual: Boolean) - extends IR - - final case class StreamMap(a: IR, name: Name, body: IR) extends TypedIR[TStream] { - def elementTyp: Type = typ.elementType - } - - final case class StreamTakeWhile(a: IR, elementName: Name, body: IR) extends IR - - final case class StreamDropWhile(a: IR, elementName: Name, body: IR) extends IR - - final case class StreamTake(a: IR, num: IR) extends IR - - final case class StreamDrop(a: IR, num: IR) extends IR - - /* Generate, in ascending order, a uniform random sample, without replacement, of numToSample - * integers in the range [0, totalRange) */ - final case class SeqSample( - totalRange: IR, - numToSample: IR, - rngState: IR, - requiresMemoryManagementPerElement: Boolean, - ) extends IR - - /* Take the child stream and sort each element into buckets based on the provided pivots. The - * first and last elements of pivots are the endpoints of the first and last interval - * respectively, should not be contained in the dataset. */ - final case class StreamDistribute( - child: IR, - pivots: IR, - path: IR, - comparisonOp: ComparisonOp[_], - spec: AbstractTypedCodecSpec, - ) extends IR - - // "Whiten" a stream of vectors by regressing out from each vector all components - // in the direction of vectors in the preceding window. For efficiency, takes - // a stream of "chunks" of vectors. - // Takes a stream of structs, with two designated fields: `prevWindow` is the - // previous window (e.g. from the previous partition), if there is one, and - // `newChunk` is the new chunk to whiten. - final case class StreamWhiten( - stream: IR, - newChunk: String, - prevWindow: String, - vecSize: Int, - windowSize: Int, - chunkSize: Int, - blockSize: Int, - normalizeAfterWhiten: Boolean, - ) extends IR - object ArrayZipBehavior extends Enumeration { type ArrayZipBehavior = Value val AssumeSameLength: Value = Value(0) @@ -501,77 +137,6 @@ package defs { val ExtendNA: Value = Value(3) } - final case class StreamZip( - as: IndexedSeq[IR], - names: IndexedSeq[Name], - body: IR, - behavior: ArrayZipBehavior, - errorID: Int = ErrorIDs.NO_ERROR, - ) extends TypedIR[TStream] - - final case class StreamMultiMerge(as: IndexedSeq[IR], key: IndexedSeq[String]) - extends TypedIR[TStream] - - final case class StreamZipJoinProducers( - contexts: IR, - ctxName: Name, - makeProducer: IR, - key: IndexedSeq[String], - curKey: Name, - curVals: Name, - joinF: IR, - ) extends TypedIR[TStream] - - /** The StreamZipJoin node assumes that input streams have distinct keys. If input streams do not - * have distinct keys, the key that is included in the result is undefined, but is likely the - * last. 
- */ - final case class StreamZipJoin( - as: IndexedSeq[IR], - key: IndexedSeq[String], - curKey: Name, - curVals: Name, - joinF: IR, - ) extends TypedIR[TStream] - - final case class StreamFilter(a: IR, name: Name, cond: IR) extends TypedIR[TStream] - - final case class StreamFlatMap(a: IR, name: Name, body: IR) extends TypedIR[TStream] - - final case class StreamFold(a: IR, zero: IR, accumName: Name, valueName: Name, body: IR) - extends IR - - object StreamFold2 { - def apply(a: StreamFold): StreamFold2 = - StreamFold2( - a.a, - FastSeq((a.accumName, a.zero)), - a.valueName, - FastSeq(a.body), - Ref(a.accumName, a.zero.typ), - ) - } - - final case class StreamFold2( - a: IR, - accum: IndexedSeq[(Name, IR)], - valueName: Name, - seq: IndexedSeq[IR], - result: IR, - ) extends IR { - assert(accum.length == seq.length) - val nameIdx: Map[Name, Int] = accum.map(_._1).zipWithIndex.toMap - } - - final case class StreamScan(a: IR, zero: IR, accumName: Name, valueName: Name, body: IR) - extends IR - - final case class StreamFor(a: IR, valueName: Name, body: IR) extends IR - - final case class StreamAgg(a: IR, name: Name, query: IR) extends IR - - final case class StreamAggScan(a: IR, name: Name, query: IR) extends IR - object StreamJoin { def apply( left: IR, @@ -647,355 +212,10 @@ package defs { } } - final case class StreamLeftIntervalJoin( - // input streams - left: IR, - right: IR, - - // names for joiner - lKeyFieldName: String, - rIntervalFieldName: String, - - // how to combine records - lname: Name, - rname: Name, - body: IR, - ) extends IR { - override protected lazy val childrenSeq: IndexedSeq[BaseIR] = - FastSeq(left, right, body) - } - - final case class StreamJoinRightDistinct( - left: IR, - right: IR, - lKey: IndexedSeq[String], - rKey: IndexedSeq[String], - l: Name, - r: Name, - joinF: IR, - joinType: String, - ) extends IR { - def isIntervalJoin: Boolean = { - if (rKey.size != 1) return false - val lKeyTyp = tcoerce[TStruct](tcoerce[TStream](left.typ).elementType).fieldType(lKey(0)) - val rKeyTyp = tcoerce[TStruct](tcoerce[TStream](right.typ).elementType).fieldType(rKey(0)) - - rKeyTyp.isInstanceOf[TInterval] && lKeyTyp != rKeyTyp - } - } - - final case class StreamLocalLDPrune( - child: IR, - r2Threshold: IR, - windowSize: IR, - maxQueueSize: IR, - nSamples: IR, - ) extends IR - - sealed trait NDArrayIR extends TypedIR[TNDArray] { + trait NDArrayIR extends TypedIR[TNDArray] { def elementTyp: Type = typ.elementType } - object MakeNDArray { - def fill(elt: IR, shape: IndexedSeq[IR], rowMajor: IR): MakeNDArray = { - val flatSize: IR = if (shape.nonEmpty) - shape.reduce((l, r) => l * r) - else - 0L - MakeNDArray( - ToArray(mapIR(rangeIR(flatSize.toI))(_ => elt)), - MakeTuple.ordered(shape), - rowMajor, - ErrorIDs.NO_ERROR, - ) - } - } - - final case class MakeNDArray(data: IR, shape: IR, rowMajor: IR, errorId: Int) extends NDArrayIR - - final case class NDArrayShape(nd: IR) extends IR - - final case class NDArrayReshape(nd: IR, shape: IR, errorID: Int) extends NDArrayIR - - final case class NDArrayConcat(nds: IR, axis: Int) extends NDArrayIR - - final case class NDArrayRef(nd: IR, idxs: IndexedSeq[IR], errorId: Int) extends IR - - final case class NDArraySlice(nd: IR, slices: IR) extends NDArrayIR - - final case class NDArrayFilter(nd: IR, keep: IndexedSeq[IR]) extends NDArrayIR - - final case class NDArrayMap(nd: IR, valueName: Name, body: IR) extends NDArrayIR - - final case class NDArrayMap2(l: IR, r: IR, lName: Name, rName: Name, body: IR, errorID: Int) - extends NDArrayIR - - 
final case class NDArrayReindex(nd: IR, indexExpr: IndexedSeq[Int]) extends NDArrayIR - - final case class NDArrayAgg(nd: IR, axes: IndexedSeq[Int]) extends IR - - final case class NDArrayWrite(nd: IR, path: IR) extends IR - - final case class NDArrayMatMul(l: IR, r: IR, errorID: Int) extends NDArrayIR - - object NDArrayQR { - def pType(mode: String, req: Boolean): PType = { - mode match { - case "r" => PCanonicalNDArray(PFloat64Required, 2, req) - case "raw" => PCanonicalTuple( - req, - PCanonicalNDArray(PFloat64Required, 2, true), - PCanonicalNDArray(PFloat64Required, 1, true), - ) - case "reduced" => PCanonicalTuple( - req, - PCanonicalNDArray(PFloat64Required, 2, true), - PCanonicalNDArray(PFloat64Required, 2, true), - ) - case "complete" => PCanonicalTuple( - req, - PCanonicalNDArray(PFloat64Required, 2, true), - PCanonicalNDArray(PFloat64Required, 2, true), - ) - } - } - } - - object NDArraySVD { - def pTypes(computeUV: Boolean, req: Boolean): PType = { - if (computeUV) { - PCanonicalTuple( - req, - PCanonicalNDArray(PFloat64Required, 2, true), - PCanonicalNDArray(PFloat64Required, 1, true), - PCanonicalNDArray(PFloat64Required, 2, true), - ) - } else { - PCanonicalNDArray(PFloat64Required, 1, req) - } - } - } - - object NDArrayInv { - val pType = PCanonicalNDArray(PFloat64Required, 2) - } - - final case class NDArrayQR(nd: IR, mode: String, errorID: Int) extends IR - - final case class NDArraySVD(nd: IR, fullMatrices: Boolean, computeUV: Boolean, errorID: Int) - extends IR - - object NDArrayEigh { - def pTypes(eigvalsOnly: Boolean, req: Boolean): PType = - if (eigvalsOnly) { - PCanonicalNDArray(PFloat64Required, 1, req) - } else { - PCanonicalTuple( - req, - PCanonicalNDArray(PFloat64Required, 1, true), - PCanonicalNDArray(PFloat64Required, 2, true), - ) - } - } - - final case class NDArrayEigh(nd: IR, eigvalsOnly: Boolean, errorID: Int) extends IR - - final case class NDArrayInv(nd: IR, errorID: Int) extends IR - - final case class AggFilter(cond: IR, aggIR: IR, isScan: Boolean) extends IR - - final case class AggExplode(array: IR, name: Name, aggBody: IR, isScan: Boolean) extends IR - - final case class AggGroupBy(key: IR, aggIR: IR, isScan: Boolean) extends IR - - final case class AggArrayPerElement( - a: IR, - elementName: Name, - indexName: Name, - aggBody: IR, - knownLength: Option[IR], - isScan: Boolean, - ) extends IR - - object ApplyAggOp { - def apply(op: AggOp, initOpArgs: IR*)(seqOpArgs: IR*): ApplyAggOp = - ApplyAggOp( - initOpArgs.toIndexedSeq, - seqOpArgs.toIndexedSeq, - AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ)), - ) - } - - final case class ApplyAggOp( - initOpArgs: IndexedSeq[IR], - seqOpArgs: IndexedSeq[IR], - aggSig: AggSignature, - ) extends IR { - - def nSeqOpArgs = seqOpArgs.length - - def nInitArgs = initOpArgs.length - - def op: AggOp = aggSig.op - } - - object AggFold { - - def min(element: IR, sortFields: IndexedSeq[SortField]): IR = { - val elementType = element.typ.asInstanceOf[TStruct] - val keyType = elementType.select(sortFields.map(_.field))._1 - minAndMaxHelper(element, keyType, StructLT(keyType, sortFields)) - } - - def max(element: IR, sortFields: IndexedSeq[SortField]): IR = { - val elementType = element.typ.asInstanceOf[TStruct] - val keyType = elementType.select(sortFields.map(_.field))._1 - minAndMaxHelper(element, keyType, StructGT(keyType, sortFields)) - } - - def all(element: IR): IR = - aggFoldIR(True()) { accum => - ApplySpecial("land", Seq.empty[Type], Seq(accum, element), TBoolean, ErrorIDs.NO_ERROR) - } { (accum1, 
accum2) => - ApplySpecial("land", Seq.empty[Type], Seq(accum1, accum2), TBoolean, ErrorIDs.NO_ERROR) - } - - private def minAndMaxHelper(element: IR, keyType: TStruct, comp: ComparisonOp[Boolean]): IR = { - val keyFields = keyType.fields.map(_.name) - - val minAndMaxZero = NA(keyType) - val aggFoldMinAccumName1 = freshName() - val aggFoldMinAccumName2 = freshName() - val aggFoldMinAccumRef1 = Ref(aggFoldMinAccumName1, keyType) - val aggFoldMinAccumRef2 = Ref(aggFoldMinAccumName2, keyType) - val minSeq = bindIR(SelectFields(element, keyFields)) { keyOfCurElementRef => - If( - IsNA(aggFoldMinAccumRef1), - keyOfCurElementRef, - If( - ApplyComparisonOp(comp, aggFoldMinAccumRef1, keyOfCurElementRef), - aggFoldMinAccumRef1, - keyOfCurElementRef, - ), - ) - } - val minComb = - If( - IsNA(aggFoldMinAccumRef1), - aggFoldMinAccumRef2, - If( - ApplyComparisonOp(comp, aggFoldMinAccumRef1, aggFoldMinAccumRef2), - aggFoldMinAccumRef1, - aggFoldMinAccumRef2, - ), - ) - - AggFold(minAndMaxZero, minSeq, minComb, aggFoldMinAccumName1, aggFoldMinAccumName2, false) - } - } - - final case class AggFold( - zero: IR, - seqOp: IR, - combOp: IR, - accumName: Name, - otherAccumName: Name, - isScan: Boolean, - ) extends IR - - object ApplyScanOp { - def apply(op: AggOp, initOpArgs: IR*)(seqOpArgs: IR*): ApplyScanOp = - ApplyScanOp( - initOpArgs.toIndexedSeq, - seqOpArgs.toIndexedSeq, - AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ)), - ) - } - - final case class ApplyScanOp( - initOpArgs: IndexedSeq[IR], - seqOpArgs: IndexedSeq[IR], - aggSig: AggSignature, - ) extends IR { - - def nSeqOpArgs = seqOpArgs.length - - def nInitArgs = initOpArgs.length - - def op: AggOp = aggSig.op - } - - final case class InitOp(i: Int, args: IndexedSeq[IR], aggSig: PhysicalAggSig) extends IR - - final case class SeqOp(i: Int, args: IndexedSeq[IR], aggSig: PhysicalAggSig) extends IR - - final case class CombOp(i1: Int, i2: Int, aggSig: PhysicalAggSig) extends IR - - object ResultOp { - def makeTuple(aggs: IndexedSeq[PhysicalAggSig]) = - MakeTuple.ordered(aggs.zipWithIndex.map { case (aggSig, index) => - ResultOp(index, aggSig) - }) - } - - final case class ResultOp(idx: Int, aggSig: PhysicalAggSig) extends IR - - final private[ir] case class CombOpValue(i: Int, value: IR, aggSig: PhysicalAggSig) extends IR - - final case class AggStateValue(i: Int, aggSig: AggStateSig) extends IR - - final case class InitFromSerializedValue(i: Int, value: IR, aggSig: AggStateSig) extends IR - - final case class SerializeAggs( - startIdx: Int, - serializedIdx: Int, - spec: BufferSpec, - aggSigs: IndexedSeq[AggStateSig], - ) extends IR - - final case class DeserializeAggs( - startIdx: Int, - serializedIdx: Int, - spec: BufferSpec, - aggSigs: IndexedSeq[AggStateSig], - ) extends IR - - final case class RunAgg(body: IR, result: IR, signature: IndexedSeq[AggStateSig]) extends IR - - final case class RunAggScan( - array: IR, - name: Name, - init: IR, - seqs: IR, - result: IR, - signature: IndexedSeq[AggStateSig], - ) extends IR - - object Begin { - def apply(xs: IndexedSeq[IR]): IR = - if (xs.isEmpty) - Void() - else - Let(xs.init.map(x => (freshName(), x)), xs.last) - } - - final case class Begin(xs: IndexedSeq[IR]) extends IR - - final case class MakeStruct(fields: IndexedSeq[(String, IR)]) extends IR - - final case class SelectFields(old: IR, fields: IndexedSeq[String]) extends IR - - object InsertFields { - def apply(old: IR, fields: IndexedSeq[(String, IR)]): InsertFields = - InsertFields(old, fields, None) - } - - final case class 
-
   object GetFieldByIdx {
     def apply(s: IR, field: Int): IR = (s.typ: @unchecked) match {
@@ -1004,80 +224,7 @@ package defs {
     }
   }
 
-  final case class GetField(o: IR, name: String) extends IR
-
-  object MakeTuple {
-    def ordered(types: IndexedSeq[IR]): MakeTuple = MakeTuple(types.zipWithIndex.map {
-      case (ir, i) =>
-        (i, ir)
-    })
-  }
-
-  final case class MakeTuple(fields: IndexedSeq[(Int, IR)]) extends IR
-
-  final case class GetTupleElement(o: IR, idx: Int) extends IR
-
-  object In {
-    def apply(i: Int, typ: Type): In = In(
-      i,
-      SingleCodeEmitParamType(
-        false,
-        typ match {
-          case TInt32 => Int32SingleCodeType
-          case TInt64 => Int64SingleCodeType
-          case TFloat32 => Float32SingleCodeType
-          case TFloat64 => Float64SingleCodeType
-          case TBoolean => BooleanSingleCodeType
-          case _: TStream => throw new UnsupportedOperationException
-          case t => PTypeReferenceSingleCodeType(PType.canonical(t))
-        },
-      ),
-    )
-  }
-
-  // Function Input
-  final case class In(i: Int, _typ: EmitParamType) extends IR
-
-  // FIXME: should be type any
-  object Die {
-    def apply(message: String, typ: Type): Die = Die(Str(message), typ, ErrorIDs.NO_ERROR)
-
-    def apply(message: String, typ: Type, errorId: Int): Die = Die(Str(message), typ, errorId)
-  }
-
-  /** the Trap node runs the `child` node with an exception handler. If the child throws a
-    * HailException (user exception), then we return the tuple ((msg, errorId), NA). If the child
-    * throws any other exception, we raise that exception. If the child does not throw, then we
-    * return the tuple (NA, child value).
-    */
-  final case class Trap(child: IR) extends IR
-
-  final case class Die(message: IR, _typ: Type, errorId: Int) extends IR
-
-  final case class ConsoleLog(message: IR, result: IR) extends IR
-
-  final case class ApplyIR(
-    function: String,
-    typeArgs: Seq[Type],
-    args: Seq[IR],
-    returnType: Type,
-    errorID: Int,
-  ) extends IR {
-    var conversion: (Seq[Type], Seq[IR], Int) => IR = _
-    var inline: Boolean = _
-
-    private lazy val refs = args.map(a => Ref(freshName(), a.typ)).toArray
-    lazy val body: IR = conversion(typeArgs, refs, errorID).deepCopy()
-    lazy val refIdx: Map[Name, Int] = refs.map(_.name).zipWithIndex.toMap
-
-    lazy val explicitNode: IR = {
-      val ir = Let(refs.map(_.name).zip(args), body)
-      assert(ir.typ == returnType)
-      ir
-    }
-  }
-
-  sealed abstract class AbstractApplyNode[F <: JVMFunction] extends IR {
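+  /** Shared interface for the `Apply*` nodes, which dispatch to a function
+    * registered in the IR function registry. A trait rather than a sealed
+    * abstract class, so that the generated `Apply`, `ApplySeeded`, and
+    * `ApplySpecial` case classes (produced by the ir-gen module) can mix it in.
+    */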
+  trait AbstractApplyNode[F <: JVMFunction] extends IR {
     def function: String
 
     def args: Seq[IR]
@@ -1093,90 +240,6 @@
         .asInstanceOf[F]
   }
 
-  final case class Apply(
-    function: String,
-    typeArgs: Seq[Type],
-    args: Seq[IR],
-    returnType: Type,
-    errorID: Int,
-  ) extends AbstractApplyNode[UnseededMissingnessObliviousJVMFunction]
-
-  final case class ApplySeeded(
-    function: String,
-    _args: Seq[IR],
-    rngState: IR,
-    staticUID: Long,
-    returnType: Type,
-  ) extends AbstractApplyNode[UnseededMissingnessObliviousJVMFunction] {
-    val args = rngState +: _args
-    val typeArgs: Seq[Type] = Seq.empty[Type]
-  }
-
-  final case class ApplySpecial(
-    function: String,
-    typeArgs: Seq[Type],
-    args: Seq[IR],
-    returnType: Type,
-    errorID: Int,
-  ) extends AbstractApplyNode[UnseededMissingnessAwareJVMFunction]
-
-  final case class LiftMeOut(child: IR) extends IR
-
-  final case class TableCount(child: TableIR) extends IR
-
-  final case class MatrixCount(child: MatrixIR) extends IR
-
-  final case class TableAggregate(child: TableIR, query: IR) extends IR
-
-  final case class MatrixAggregate(child: MatrixIR, query: IR) extends IR
-
-  final case class TableWrite(child: TableIR, writer: TableWriter) extends IR
-
-  final case class TableMultiWrite(
-    _children: IndexedSeq[TableIR],
-    writer: WrappedMatrixNativeMultiWriter,
-  ) extends IR
-
-  final case class TableGetGlobals(child: TableIR) extends IR
-
-  final case class TableCollect(child: TableIR) extends IR
-
-  final case class MatrixWrite(child: MatrixIR, writer: MatrixWriter) extends IR
-
-  final case class MatrixMultiWrite(
-    _children: IndexedSeq[MatrixIR],
-    writer: MatrixNativeMultiWriter,
-  ) extends IR
-
-  final case class TableToValueApply(child: TableIR, function: TableToValueFunction) extends IR
-
-  final case class MatrixToValueApply(child: MatrixIR, function: MatrixToValueFunction) extends IR
-
-  final case class BlockMatrixToValueApply(
-    child: BlockMatrixIR,
-    function: BlockMatrixToValueFunction,
-  ) extends IR
-
-  final case class BlockMatrixCollect(child: BlockMatrixIR) extends NDArrayIR
-
-  final case class BlockMatrixWrite(child: BlockMatrixIR, writer: BlockMatrixWriter) extends IR
-
-  final case class BlockMatrixMultiWrite(
-    blockMatrices: IndexedSeq[BlockMatrixIR],
-    writer: BlockMatrixMultiWriter,
-  ) extends IR
-
-  final case class CollectDistributedArray(
-    contexts: IR,
-    globals: IR,
-    cname: Name,
-    gname: Name,
-    body: IR,
-    dynamicID: IR,
-    staticID: String,
-    tsd: Option[TableStageDependency] = None,
-  ) extends IR
-
   object PartitionReader {
     implicit val formats: Formats = new DefaultFormats() {
@@ -1395,25 +458,10 @@ package defs {
   final case class SimpleMetadataWriter(val annotationType: Type) extends MetadataWriter {
     def writeMetadata(writeAnnotations: => IEmitCode, cb: EmitCodeBuilder, region: Value[Region])
-      : Unit =
+        : Unit =
       writeAnnotations.consume(cb, {}, _ => ())
   }
 
-  final case class ReadPartition(context: IR, rowType: TStruct, reader: PartitionReader) extends IR
-
-  final case class WritePartition(value: IR, writeCtx: IR, writer: PartitionWriter) extends IR
-
-  final case class WriteMetadata(writeAnnotations: IR, writer: MetadataWriter) extends IR
-
-  final case class ReadValue(path: IR, reader: ValueReader, requestedType: Type) extends IR
-
-  final case class WriteValue(
-    value: IR,
-    path: IR,
-    writer: ValueWriter,
-    stagingFile: Option[IR] = None,
-  ) extends IR
-
   class PrimitiveIR(val self: IR) extends AnyVal {
     def +(other: IR): IR = {
       assert(self.typ == other.typ)
@@ -1463,4 +511,394 @@ package defs {
   object ErrorIDs {
     val NO_ERROR = -1
   }
+
+  package exts {
+
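+    /* Hand-written companions to the generated IR node definitions: the smart
+     * constructors, extractors, and physical-type helpers that previously
+     * lived on the case classes removed above now live in these `*Ext`
+     * objects. */
+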
+    object UUID4Ext {
+      def apply(): UUID4 = UUID4(genUID())
+    }
+
+    object MakeArrayExt {
+      def apply(args: IR*): MakeArray = {
+        assert(args.nonEmpty)
+        MakeArray(args.toFastSeq, TArray(args.head.typ))
+      }
+
+      def unify(ctx: ExecuteContext, args: IndexedSeq[IR], requestedType: TArray = null)
+        : MakeArray = {
+        assert(requestedType != null || args.nonEmpty)
+
+        if (args.nonEmpty)
+          if (args.forall(_.typ == args.head.typ))
+            return MakeArray(args, TArray(args.head.typ))
+
+        MakeArray(
+          args.map { arg =>
+            val upcast = PruneDeadFields.upcast(ctx, arg, requestedType.elementType)
+            assert(upcast.typ == requestedType.elementType)
+            upcast
+          },
+          requestedType,
+        )
+      }
+    }
+
+    object LiteralExt {
+      def coerce(t: Type, x: Any): IR = {
+        if (x == null)
+          return NA(t)
+        t match {
+          case TInt32 => I32(x.asInstanceOf[Number].intValue())
+          case TInt64 => I64(x.asInstanceOf[Number].longValue())
+          case TFloat32 => F32(x.asInstanceOf[Number].floatValue())
+          case TFloat64 => F64(x.asInstanceOf[Number].doubleValue())
+          case TBoolean => if (x.asInstanceOf[Boolean]) True() else False()
+          case TString => Str(x.asInstanceOf[String])
+          case _ => Literal(t, x)
+        }
+      }
+    }
+
+    object EncodedLiteralExt {
+      def apply(codec: AbstractTypedCodecSpec, value: Array[Array[Byte]]): EncodedLiteral =
+        EncodedLiteral(codec, new WrappedByteArrays(value))
+
+      def fromPTypeAndAddress(pt: PType, addr: Long, ctx: ExecuteContext): IR = {
+        pt match {
+          case _: PInt32 => I32(Region.loadInt(addr))
+          case _: PInt64 => I64(Region.loadLong(addr))
+          case _: PFloat32 => F32(Region.loadFloat(addr))
+          case _: PFloat64 => F64(Region.loadDouble(addr))
+          case _: PBoolean => if (Region.loadBoolean(addr)) True() else False()
+          case ts: PString => Str(ts.loadString(addr))
+          case _ =>
+            val etype = EType.defaultFromPType(ctx, pt)
+            val codec = TypedCodecSpec(etype, pt.virtualType, BufferSpec.wireSpec)
+            val bytes = codec.encodeArrays(ctx, pt, addr)
+            EncodedLiteral(codec, bytes)
+        }
+      }
+    }
+
+    object BlockExt {
+      object Insert {
+        def unapply(bindings: IndexedSeq[Binding])
+          : Option[(IndexedSeq[Binding], Binding, IndexedSeq[Binding])] = {
+          val idx = bindings.indexWhere(_.value.isInstanceOf[InsertFields])
+          if (idx == -1) None else Some((bindings.take(idx), bindings(idx), bindings.drop(idx + 1)))
+        }
+      }
+
+      object Nested {
+        def unapply(bindings: IndexedSeq[Binding]): Option[(Int, IndexedSeq[Binding])] = {
+          val idx = bindings.indexWhere(_.value.isInstanceOf[Block])
+          if (idx == -1) None else Some((idx, bindings))
+        }
+      }
+    }
+
+    object MakeStreamExt {
+      def unify(
+        ctx: ExecuteContext,
+        args: IndexedSeq[IR],
+        requiresMemoryManagementPerElement: Boolean = false,
+        requestedType: TStream = null,
+      ): MakeStream = {
+        assert(requestedType != null || args.nonEmpty)
+
+        if (args.nonEmpty)
+          if (args.forall(_.typ == args.head.typ))
+            return MakeStream(args, TStream(args.head.typ), requiresMemoryManagementPerElement)
+
+        MakeStream(
+          args.map { arg =>
+            val upcast = PruneDeadFields.upcast(ctx, arg, requestedType.elementType)
+            assert(upcast.typ == requestedType.elementType)
+            upcast
+          },
+          requestedType,
+          requiresMemoryManagementPerElement,
+        )
+      }
+    }
+
+    object ArraySortExt {
+      def apply(a: IR, ascending: IR = True(), onKey: Boolean = false): ArraySort = {
+        val l = freshName()
+        val r = freshName()
+        val atyp = tcoerce[TStream](a.typ)
+        val compare = if (onKey) {
+          val elementType = atyp.elementType.asInstanceOf[TBaseStruct]
+          elementType match {
+            case _: TStruct =>
+              val elt = tcoerce[TStruct](atyp.elementType)
+              ApplyComparisonOp(
+                Compare(elt.types(0)),
+                GetField(Ref(l, elt), elt.fieldNames(0)),
+                GetField(Ref(r, atyp.elementType), elt.fieldNames(0)),
+              )
+            case _: TTuple =>
+              val elt = tcoerce[TTuple](atyp.elementType)
+              ApplyComparisonOp(
+                Compare(elt.types(0)),
+                GetTupleElement(Ref(l, elt), elt.fields(0).index),
+                GetTupleElement(Ref(r, atyp.elementType), elt.fields(0).index),
+              )
+          }
+        } else {
+          ApplyComparisonOp(
+            Compare(atyp.elementType),
+            Ref(l, atyp.elementType),
+            Ref(r, atyp.elementType),
+          )
+        }
+
+        ArraySort(a, l, r, If(ascending, compare < 0, compare > 0))
+      }
+    }
+
+    object StreamFold2Ext {
+      def apply(a: StreamFold): StreamFold2 =
+        StreamFold2(
+          a.a,
+          FastSeq((a.accumName, a.zero)),
+          a.valueName,
+          FastSeq(a.body),
+          Ref(a.accumName, a.zero.typ),
+        )
+    }
+
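+    /* Derived queries on generated nodes are expressed as `AnyVal` wrappers
+     * like the class below: wrapping a node just to call `isIntervalJoin`
+     * typically compiles away without allocating. */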
+    final class StreamJoinRightDistinctExt(val node: StreamJoinRightDistinct) extends AnyVal {
+      import node._
+
+      def isIntervalJoin: Boolean = {
+        if (rKey.size != 1) return false
+        val lKeyTyp = tcoerce[TStruct](tcoerce[TStream](left.typ).elementType).fieldType(lKey(0))
+        val rKeyTyp = tcoerce[TStruct](tcoerce[TStream](right.typ).elementType).fieldType(rKey(0))
+
+        rKeyTyp.isInstanceOf[TInterval] && lKeyTyp != rKeyTyp
+      }
+    }
+
+    object MakeNDArrayExt {
+      def fill(elt: IR, shape: IndexedSeq[IR], rowMajor: IR): MakeNDArray = {
+        val flatSize: IR = if (shape.nonEmpty)
+          shape.reduce((l, r) => l * r)
+        else
+          0L
+        MakeNDArray(
+          ToArray(mapIR(rangeIR(flatSize.toI))(_ => elt)),
+          MakeTuple.ordered(shape),
+          rowMajor,
+          ErrorIDs.NO_ERROR,
+        )
+      }
+    }
+
+    object NDArrayQRExt {
+      def pType(mode: String, req: Boolean): PType = {
+        mode match {
+          case "r" => PCanonicalNDArray(PFloat64Required, 2, req)
+          case "raw" => PCanonicalTuple(
+              req,
+              PCanonicalNDArray(PFloat64Required, 2, true),
+              PCanonicalNDArray(PFloat64Required, 1, true),
+            )
+          case "reduced" => PCanonicalTuple(
+              req,
+              PCanonicalNDArray(PFloat64Required, 2, true),
+              PCanonicalNDArray(PFloat64Required, 2, true),
+            )
+          case "complete" => PCanonicalTuple(
+              req,
+              PCanonicalNDArray(PFloat64Required, 2, true),
+              PCanonicalNDArray(PFloat64Required, 2, true),
+            )
+        }
+      }
+    }
+
+    object NDArraySVDExt {
+      def pTypes(computeUV: Boolean, req: Boolean): PType = {
+        if (computeUV) {
+          PCanonicalTuple(
+            req,
+            PCanonicalNDArray(PFloat64Required, 2, true),
+            PCanonicalNDArray(PFloat64Required, 1, true),
+            PCanonicalNDArray(PFloat64Required, 2, true),
+          )
+        } else {
+          PCanonicalNDArray(PFloat64Required, 1, req)
+        }
+      }
+    }
+
+    object NDArrayEighExt {
+      def pTypes(eigvalsOnly: Boolean, req: Boolean): PType =
+        if (eigvalsOnly) {
+          PCanonicalNDArray(PFloat64Required, 1, req)
+        } else {
+          PCanonicalTuple(
+            req,
+            PCanonicalNDArray(PFloat64Required, 1, true),
+            PCanonicalNDArray(PFloat64Required, 2, true),
+          )
+        }
+    }
+
+    object NDArrayInvExt {
+      val pType = PCanonicalNDArray(PFloat64Required, 2)
+    }
+
+    object ApplyAggOpExt {
+      def apply(op: AggOp, initOpArgs: IR*)(seqOpArgs: IR*): ApplyAggOp =
+        ApplyAggOp(
+          initOpArgs.toIndexedSeq,
+          seqOpArgs.toIndexedSeq,
+          AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ)),
+        )
+    }
+
+    final class ApplyAggOpExt(val node: ApplyAggOp) extends AnyVal {
+      import node._
+
+      def nSeqOpArgs = seqOpArgs.length
+
+      def nInitArgs = initOpArgs.length
+
+      def op: AggOp = aggSig.op
+    }
+
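+    /* Helpers that lower min/max/all aggregations onto the generic AggFold
+     * node. For min and max, the accumulator holds the best key seen so far
+     * over the fields named by `sortFields`, with a missing value (NA) as the
+     * zero of the fold. */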
+    object AggFoldExt {
+      def min(element: IR, sortFields: IndexedSeq[SortField]): IR = {
+        val elementType = element.typ.asInstanceOf[TStruct]
+        val keyType = elementType.select(sortFields.map(_.field))._1
+        minAndMaxHelper(element, keyType, StructLT(keyType, sortFields))
+      }
+
+      def max(element: IR, sortFields: IndexedSeq[SortField]): IR = {
+        val elementType = element.typ.asInstanceOf[TStruct]
+        val keyType = elementType.select(sortFields.map(_.field))._1
+        minAndMaxHelper(element, keyType, StructGT(keyType, sortFields))
+      }
+
+      def all(element: IR): IR =
+        aggFoldIR(True()) { accum =>
+          ApplySpecial(
+            "land",
+            Seq.empty[Type],
+            FastSeq(accum, element),
+            TBoolean,
+            ErrorIDs.NO_ERROR,
+          )
+        } { (accum1, accum2) =>
+          ApplySpecial(
+            "land",
+            Seq.empty[Type],
+            FastSeq(accum1, accum2),
+            TBoolean,
+            ErrorIDs.NO_ERROR,
+          )
+        }
+
+      private def minAndMaxHelper(element: IR, keyType: TStruct, comp: ComparisonOp[Boolean])
+        : IR = {
+        val keyFields = keyType.fields.map(_.name)
+
+        val minAndMaxZero = NA(keyType)
+        val aggFoldMinAccumName1 = freshName()
+        val aggFoldMinAccumName2 = freshName()
+        val aggFoldMinAccumRef1 = Ref(aggFoldMinAccumName1, keyType)
+        val aggFoldMinAccumRef2 = Ref(aggFoldMinAccumName2, keyType)
+        val minSeq = bindIR(SelectFields(element, keyFields)) { keyOfCurElementRef =>
+          If(
+            IsNA(aggFoldMinAccumRef1),
+            keyOfCurElementRef,
+            If(
+              ApplyComparisonOp(comp, aggFoldMinAccumRef1, keyOfCurElementRef),
+              aggFoldMinAccumRef1,
+              keyOfCurElementRef,
+            ),
+          )
+        }
+        val minComb =
+          If(
+            IsNA(aggFoldMinAccumRef1),
+            aggFoldMinAccumRef2,
+            If(
+              ApplyComparisonOp(comp, aggFoldMinAccumRef1, aggFoldMinAccumRef2),
+              aggFoldMinAccumRef1,
+              aggFoldMinAccumRef2,
+            ),
+          )
+
+        AggFold(minAndMaxZero, minSeq, minComb, aggFoldMinAccumName1, aggFoldMinAccumName2, false)
+      }
+    }
+
+    object ApplyScanOpExt {
+      def apply(op: AggOp, initOpArgs: IR*)(seqOpArgs: IR*): ApplyScanOp =
+        ApplyScanOp(
+          initOpArgs.toIndexedSeq,
+          seqOpArgs.toIndexedSeq,
+          AggSignature(op, initOpArgs.map(_.typ), seqOpArgs.map(_.typ)),
+        )
+    }
+
+    final class ApplyScanOpExt(val node: ApplyScanOp) extends AnyVal {
+      import node._
+
+      def nSeqOpArgs = seqOpArgs.length
+
+      def nInitArgs = initOpArgs.length
+
+      def op: AggOp = aggSig.op
+    }
+
+    object ResultOpExt {
+      def makeTuple(aggs: IndexedSeq[PhysicalAggSig]) =
+        MakeTuple.ordered(aggs.zipWithIndex.map { case (aggSig, index) =>
+          ResultOp(index, aggSig)
+        })
+    }
+
+    object BeginExt {
+      def apply(xs: IndexedSeq[IR]): IR =
+        if (xs.isEmpty)
+          Void()
+        else
+          Let(xs.init.map(x => (freshName(), x)), xs.last)
+    }
+
+    object MakeTupleExt {
+      def ordered(types: IndexedSeq[IR]): MakeTuple = MakeTuple(types.zipWithIndex.map {
+        case (ir, i) =>
+          (i, ir)
+      })
+    }
+
+    object InExt {
+      def apply(i: Int, typ: Type): In = In(
+        i,
+        SingleCodeEmitParamType(
+          false,
+          typ match {
+            case TInt32 => Int32SingleCodeType
+            case TInt64 => Int64SingleCodeType
+            case TFloat32 => Float32SingleCodeType
+            case TFloat64 => Float64SingleCodeType
+            case TBoolean => BooleanSingleCodeType
+            case _: TStream => throw new UnsupportedOperationException
+            case t => PTypeReferenceSingleCodeType(PType.canonical(t))
+          },
+        ),
+      )
+    }
+
+    object DieExt {
+      def apply(message: String, typ: Type): Die = Die(Str(message), typ, ErrorIDs.NO_ERROR)
+
+      def apply(message: String, typ: Type, errorId: Int): Die = Die(Str(message), typ, errorId)
+    }
+  }
 }
diff --git a/hail/hail/src/is/hail/expr/ir/Parser.scala b/hail/hail/src/is/hail/expr/ir/Parser.scala
index 4c4a5203055..e147b5b9273 100644
--- a/hail/hail/src/is/hail/expr/ir/Parser.scala
+++ b/hail/hail/src/is/hail/expr/ir/Parser.scala
@@ -798,7 +798,7 @@ object IRParser {
 
   def apply_like(
     env: IRParserEnvironment,
-    cons: (String, Seq[Type], Seq[IR], Type, Int) => IR,
+    cons: (String, Seq[Type], IndexedSeq[IR], Type, Int) => IR,
   )(
     it: TokenIterator
   ): StackFrame[IR] = {
diff --git a/hail/hail/src/is/hail/expr/ir/TypeCheck.scala b/hail/hail/src/is/hail/expr/ir/TypeCheck.scala
index b6636ec9c84..228980bf562 100644
--- a/hail/hail/src/is/hail/expr/ir/TypeCheck.scala
+++ b/hail/hail/src/is/hail/expr/ir/TypeCheck.scala
@@ -345,7 +345,7 @@ object TypeCheck {
       assert(key.forall(structType.hasField))
     case x @ StreamMap(a, _, body) =>
       assert(a.typ.isInstanceOf[TStream])
-      assert(x.elementTyp == body.typ)
+      assert(x.typ.elementType == body.typ)
     case x @ StreamZip(as, names, body, _, _) =>
       assert(as.length == names.length)
       assert(x.typ.elementType == body.typ)
diff --git a/hail/hail/src/is/hail/expr/ir/functions/ArrayFunctions.scala b/hail/hail/src/is/hail/expr/ir/functions/ArrayFunctions.scala
index 582e238668b..7cada291e5b 100644
--- a/hail/hail/src/is/hail/expr/ir/functions/ArrayFunctions.scala
+++ b/hail/hail/src/is/hail/expr/ir/functions/ArrayFunctions.scala
@@ -55,14 +55,14 @@
       tnum("T"),
       TFloat64,
       (ir1: IR, ir2: IR, errorID: Int) =>
-        Apply("pow", Seq(), Seq(ir1, ir2), TFloat64, errorID),
+        Apply("pow", Seq(), FastSeq(ir1, ir2), TFloat64, errorID),
     ),
     (
       "mod",
       tnum("T"),
       tv("T"),
       (ir1: IR, ir2: IR, errorID: Int) =>
-        Apply("mod", Seq(), Seq(ir1, ir2), ir2.typ, errorID),
+        Apply("mod", Seq(), FastSeq(ir1, ir2), ir2.typ, errorID),
     ),
   )
diff --git a/hail/hail/src/is/hail/expr/ir/functions/Functions.scala b/hail/hail/src/is/hail/expr/ir/functions/Functions.scala
index da60de9b5ab..9154a0634b1 100644
--- a/hail/hail/src/is/hail/expr/ir/functions/Functions.scala
+++ b/hail/hail/src/is/hail/expr/ir/functions/Functions.scala
@@ -178,7 +178,7 @@ object IRFunctionRegistry {
     : Option[(Seq[IR], IR) => IR] =
     lookupFunction(name, returnType, Array.empty[Type], TRNGState +: arguments)
       .map { f => (irArguments: Seq[IR], rngState: IR) =>
-        ApplySeeded(name, irArguments, rngState, staticUID, f.returnType.subst())
+        ApplySeeded(name, irArguments.toFastSeq, rngState, staticUID, f.returnType.subst())
       }
 
   def lookupUnseeded(name: String, returnType: Type, arguments: Seq[Type])
@@ -194,7 +194,7 @@ object IRFunctionRegistry {
     val validIR: Option[IRFunctionImplementation] =
       lookupIR(name, returnType, typeParameters, arguments).map {
         case ((_, _, _, inline), conversion) => (typeParametersPassed, args, errorID) =>
-          val x = ApplyIR(name, typeParametersPassed, args, returnType, errorID)
+          val x = ApplyIR(name, typeParametersPassed, args.toFastSeq, returnType, errorID)
           x.conversion = conversion
           x.inline = inline
           x
@@ -205,9 +205,21 @@
     { (irValueParametersTypes: Seq[Type], irArguments: Seq[IR], errorID: Int) =>
       f match {
         case _: UnseededMissingnessObliviousJVMFunction =>
-          Apply(name, irValueParametersTypes, irArguments, f.returnType.subst(), errorID)
+          Apply(
+            name,
+            irValueParametersTypes,
+            irArguments.toFastSeq,
+            f.returnType.subst(),
+            errorID,
+          )
        case _: UnseededMissingnessAwareJVMFunction =>
-          ApplySpecial(name, irValueParametersTypes, irArguments, f.returnType.subst(), errorID)
+          ApplySpecial(
+            name,
+            irValueParametersTypes,
+            irArguments.toFastSeq,
+            f.returnType.subst(),
+            errorID,
+          )
       }
     }
   }
diff --git a/hail/hail/test/src/is/hail/expr/ir/FoldConstantsSuite.scala b/hail/hail/test/src/is/hail/expr/ir/FoldConstantsSuite.scala
index 789ec049d24..91b34b45d1e 100644
--- a/hail/hail/test/src/is/hail/expr/ir/FoldConstantsSuite.scala
+++ b/hail/hail/test/src/is/hail/expr/ir/FoldConstantsSuite.scala
@@ -2,15 +2,13 @@ package is.hail.expr.ir
 
 import is.hail.HailSuite
 import is.hail.types.virtual.{TFloat64, TInt32}
-import is.hail.expr.ir.defs.{
-  ApplySeeded, F64, RNGStateLiteral, Str, AggLet, I32, ApplyAggOp, I64, ApplyScanOp,
-}
-
+import is.hail.expr.ir.defs.{AggLet, ApplyAggOp, ApplyScanOp, ApplySeeded, F64, I32, I64, RNGStateLiteral, Str}
+import is.hail.utils.FastSeq
 import org.testng.annotations.{DataProvider, Test}
 
 class FoldConstantsSuite extends HailSuite {
   @Test def testRandomBlocksFolding(): Unit = {
-    val x = ApplySeeded("rand_norm", Seq(F64(0d), F64(0d)), RNGStateLiteral(), 0L, TFloat64)
+    val x = ApplySeeded("rand_norm", FastSeq(F64(0d), F64(0d)), RNGStateLiteral(), 0L, TFloat64)
     assert(FoldConstants(ctx, x) == x)
   }
diff --git a/hail/hail/test/src/is/hail/expr/ir/EncodedLiteralSuite.scala b/hail/hail/test/src/is/hail/expr/ir/defs/EncodedLiteralSuite.scala
similarity index 88%
rename from hail/hail/test/src/is/hail/expr/ir/EncodedLiteralSuite.scala
rename to hail/hail/test/src/is/hail/expr/ir/defs/EncodedLiteralSuite.scala
index 78fb54ea776..53b5ad147db 100644
--- a/hail/hail/test/src/is/hail/expr/ir/EncodedLiteralSuite.scala
+++ b/hail/hail/test/src/is/hail/expr/ir/defs/EncodedLiteralSuite.scala
@@ -1,7 +1,6 @@
-package is.hail.expr.ir
+package is.hail.expr.ir.defs
 
 import is.hail.HailSuite
-import is.hail.expr.ir.defs.WrappedByteArrays
 
 import org.testng.annotations.Test