Support BloomFilterMightContain expression #8775

Merged · 6 commits · Aug 2, 2023
8 changes: 4 additions & 4 deletions docs/supported_ops.md
@@ -18339,11 +18339,11 @@ are limited.
<td>S</td>
<td>S</td>
<td>S</td>
<td>S</td>
<td><b>NS</b></td>
<td><b>NS</b></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types BINARY, CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><em>PS<br/>UTC is only supported TZ for child TIMESTAMP;<br/>unsupported child types CALENDAR, UDT</em></td>
<td><b>NS</b></td>
</tr>
<tr>
@@ -3464,8 +3464,8 @@ object GpuOverrides extends Logging {
expr[org.apache.spark.sql.execution.ScalarSubquery](
"Subquery that will return only one row and one column",
ExprChecks.projectOnly(
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128
+ TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT).nested(),
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.DECIMAL_128 + TypeSig.BINARY +
TypeSig.ARRAY + TypeSig.MAP + TypeSig.STRUCT).nested(),
TypeSig.all,
Nil, None),
(a, conf, p, r) =>
@@ -3535,7 +3535,8 @@ object GpuOverrides extends Logging {
// Shim expressions should be last to allow overrides with shim-specific versions
val expressions: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] =
commonExpressions ++ TimeStamp.getExprs ++ GpuHiveOverrides.exprs ++
ZOrderRules.exprs ++ DecimalArithmeticOverrides.exprs ++ SparkShimImpl.getExprs
ZOrderRules.exprs ++ DecimalArithmeticOverrides.exprs ++
BloomFilterShims.exprs ++ SparkShimImpl.getExprs

def wrapScan[INPUT <: Scan](
scan: INPUT,
@@ -0,0 +1,38 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "311"}
{"spark": "312"}
{"spark": "313"}
{"spark": "320"}
{"spark": "321"}
{"spark": "321cdh"}
{"spark": "321db" }
{"spark": "322"}
{"spark": "323"}
{"spark": "324"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids.ExprRule

import org.apache.spark.sql.catalyst.expressions.Expression


object BloomFilterShims {
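// BloomFilterMightContain was introduced in Spark 3.3.0, so there is nothing to
// override on these earlier Spark versions; the rule map is intentionally empty.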
val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = Map.empty
}
@@ -0,0 +1,127 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
{"spark": "330cdh"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "340"}
{"spark": "341"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids

import java.io.DataInputStream

import ai.rapids.cudf.{BaseDeviceMemoryBuffer, ColumnVector, Cuda, DeviceMemoryBuffer, DType, HostMemoryBuffer}
import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
import com.nvidia.spark.rapids.jni.BloomFilter

import org.apache.spark.sql.types.{BinaryType, NullType}

/**
* GPU version of Spark's BloomFilterImpl.
* @param numHashes number of hash functions to use in the Bloom filter
* @param buffer device buffer containing the Bloom filter bits, i.e. the data portion
*               of Spark's Bloom filter serialization format with the 12-byte header
*               already stripped. The device buffer will be closed when this
*               GpuBloomFilter instance is closed.
*/
class GpuBloomFilter(val numHashes: Int, buffer: DeviceMemoryBuffer) extends AutoCloseable {
private val spillableBuffer = SpillableBuffer(buffer, SpillPriorities.ACTIVE_ON_DECK_PRIORITY)
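// Total size of the filter in bits, used by BloomFilter.probe to map hashes
// to bit positions.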
private val numFilterBits = buffer.getLength * 8

/**
* Given an input column of longs, return a boolean column with the same row count where each
* output row indicates whether the corresponding input row may have been placed into this
* Bloom filter. A false value indicates definitively that the value was not placed in the filter.
*/
def mightContainLong(col: ColumnVector): ColumnVector = {
require(col.getType == DType.INT64, s"expected longs, got ${col.getType}")
withResource(spillableBuffer.getDeviceBuffer()) { buffer =>
BloomFilter.probe(numHashes, numFilterBits, buffer, col)
}
}

override def close(): Unit = {
spillableBuffer.close()
}
}

object GpuBloomFilter {
// Spark serializes its Bloom filters in a specific format; see BloomFilterImpl.readFrom.
// Data is written via DataOutputStream, so everything is big-endian.
// Byte Offset Size Description
// 0 4 Version ID (see Spark's BloomFilter.Version)
// 4 4 Number of hash functions
// 8 4 Number of longs, N
// 12 N*8 Bloom filter data buffer as longs
private val HEADER_SIZE = 12

// version numbers from BloomFilter.Version enum
private val VERSION_V1 = 1

def apply(s: GpuScalar): GpuBloomFilter = {
s.dataType match {
case BinaryType if s.isValid =>
withResource(s.getBase.getListAsColumnView) { childView =>
require(childView.getType == DType.UINT8, s"expected UINT8, got ${childView.getType}")
deserialize(childView.getData)
}
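// An invalid (null) binary scalar or a typed null literal produces no filter;
// columnarEval later maps a null filter to an all-null result column.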
case BinaryType | NullType => null
case t => throw new IllegalArgumentException(s"Expected binary or null scalar, found $t")
}
}

def deserialize(data: BaseDeviceMemoryBuffer): GpuBloomFilter = {
// Sanity-check the Bloom filter header
val totalLen = data.getLength
val bufferLen = totalLen - HEADER_SIZE
require(totalLen >= HEADER_SIZE,
s"total serialized length $totalLen is smaller than the $HEADER_SIZE-byte header")
require(bufferLen % 8 == 0, s"filter data length $bufferLen is not a multiple of 8")
val numHashes = withResource(HostMemoryBuffer.allocate(HEADER_SIZE, false)) { hostHeader =>
hostHeader.copyFromMemoryBuffer(0, data, 0, HEADER_SIZE, Cuda.DEFAULT_STREAM)
parseHeader(hostHeader, bufferLen)
}
// TODO: Can we avoid this copy? Would either need the ability to release data buffers
// from scalars or make scalars spillable.
val filterBuffer = DeviceMemoryBuffer.allocate(bufferLen)
// Construct inside closeOnExcept so the buffer is freed if construction fails
closeOnExcept(filterBuffer) { buf =>
buf.copyFromDeviceBufferAsync(0, data, HEADER_SIZE, buf.getLength, Cuda.DEFAULT_STREAM)
new GpuBloomFilter(numHashes, buf)
}
}

/**
* Parses the Spark Bloom filter serialization header performing sanity checks
* and retrieving the number of hash functions used for the filter.
* @param buffer serialized header data
* @param dataLen size of the serialized Bloom filter data without header
* @return number of hash functions used in the Bloom filter
*/
private def parseHeader(buffer: HostMemoryBuffer, dataLen: Long): Int = {
val in = new DataInputStream(new HostMemoryInputStream(buffer, buffer.getLength))
val version = in.readInt()
require(version == VERSION_V1, s"unsupported serialization format version $version")
val numHashes = in.readInt()
val sizeFromHeader = in.readInt() * 8L
require(dataLen == sizeFromHeader,
s"data size from header is $sizeFromHeader, received $dataLen")
numHashes
}
}
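As a cross-check on the header layout documented above, here is a minimal CPU-side sketch (an editorial illustration, not part of this change; it assumes the org.apache.spark.util.sketch API from Spark's spark-sketch module) that serializes a Bloom filter and reads back the same three header fields parseHeader consumes:

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream}

import org.apache.spark.util.sketch.BloomFilter

object BloomFilterHeaderDemo {
  def main(args: Array[String]): Unit = {
    val bf = BloomFilter.create(1000) // expected number of items
    (0L until 100L).foreach(bf.putLong)
    val out = new ByteArrayOutputStream()
    bf.writeTo(out) // big-endian: version, numHashes, numWords, then filter data
    val in = new DataInputStream(new ByteArrayInputStream(out.toByteArray))
    val version = in.readInt() // 1 == VERSION_V1
    val numHashes = in.readInt()
    val dataBytes = in.readInt() * 8L // number of longs converted to bytes
    println(s"version=$version numHashes=$numHashes dataBytes=$dataBytes")
  }
}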
@@ -0,0 +1,104 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
{"spark": "330cdh"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "340"}
{"spark": "341"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids

import com.nvidia.spark.rapids.Arm.{withResource, withResourceIfAllowed}
import com.nvidia.spark.rapids.RapidsPluginImplicits.ReallyAGpuExpression
import com.nvidia.spark.rapids.shims.ShimBinaryExpression

import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField, PlanExpression}
import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch

case class GpuBloomFilterMightContain(
bloomFilterExpression: Expression,
valueExpression: Expression)
extends ShimBinaryExpression with GpuExpression with AutoCloseable {

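// Evaluated lazily so the Bloom filter is deserialized at most once per task;
// a task completion listener releases its GPU memory when the task finishes.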
@transient private lazy val bloomFilter: GpuBloomFilter = {
Option(TaskContext.get).foreach(_.addTaskCompletionListener[Unit](_ => close()))
withResourceIfAllowed(bloomFilterExpression.columnarEvalAny(new ColumnarBatch(Array.empty))) {
case s: GpuScalar => GpuBloomFilter(s)
case x => throw new IllegalStateException(s"Expected GPU scalar, found $x")
}
}

override def nullable: Boolean = true

override def left: Expression = bloomFilterExpression

override def right: Expression = valueExpression

override def prettyName: String = "might_contain"

override def dataType: DataType = BooleanType

override def checkInputDataTypes(): TypeCheckResult = {
(left.dataType, right.dataType) match {
case (BinaryType, NullType) | (NullType, LongType) | (NullType, NullType) |
(BinaryType, LongType) =>
bloomFilterExpression match {
case e: Expression if e.foldable => TypeCheckResult.TypeCheckSuccess
case subquery: PlanExpression[_] if !subquery.containsPattern(OUTER_REFERENCE) =>
TypeCheckResult.TypeCheckSuccess
case GetStructField(subquery: PlanExpression[_], _, _)
if !subquery.containsPattern(OUTER_REFERENCE) =>
TypeCheckResult.TypeCheckSuccess
case _ =>
TypeCheckResult.TypeCheckFailure(s"The Bloom filter binary input to $prettyName " +
"should be either a constant value or a scalar subquery expression")
}
case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
s"been ${BinaryType.simpleString} followed by a value with ${LongType.simpleString}, " +
s"but it's [${left.dataType.catalogString}, ${right.dataType.catalogString}].")
}
}

override def columnarEval(batch: ColumnarBatch): GpuColumnVector = {
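// Match CPU BloomFilterMightContain semantics: a null filter or a null input
// value produces a null result rather than an error.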
if (bloomFilter == null) {
GpuColumnVector.fromNull(batch.numRows(), dataType)
} else {
withResource(valueExpression.columnarEval(batch)) { value =>
if (value == null || value.dataType == NullType) {
GpuColumnVector.fromNull(batch.numRows(), dataType)
} else {
GpuColumnVector.from(bloomFilter.mightContainLong(value.getBase), BooleanType)
}
}
}
}

override def close(): Unit = {
if (bloomFilter != null) {
bloomFilter.close()
}
}
}
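For context on how this expression reaches a plan at all: users do not normally write might_contain by hand; Spark's runtime filtering optimization injects BloomFilterMightContain on the probe side of eligible joins. A hedged sketch of enabling it follows (the table and column names are hypothetical; the config key exists as of Spark 3.3.0, and its default varies by Spark version):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("runtime-bloom-filter").getOrCreate()
spark.conf.set("spark.sql.optimizer.runtime.bloomFilter.enabled", "true")
// With the RAPIDS plugin active, the injected BloomFilterMightContain is matched
// by the rule in the next file and replaced with GpuBloomFilterMightContain.
val pruned = spark.sql(
  "SELECT f.* FROM facts f JOIN dims d ON f.key = d.key WHERE d.flag")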
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*** spark-rapids-shim-json-lines
{"spark": "330"}
{"spark": "330cdh"}
{"spark": "330db"}
{"spark": "331"}
{"spark": "332"}
{"spark": "332db"}
{"spark": "333"}
{"spark": "340"}
{"spark": "341"}
spark-rapids-shim-json-lines ***/
package com.nvidia.spark.rapids.shims

import com.nvidia.spark.rapids._

import org.apache.spark.sql.catalyst.expressions._


object BloomFilterShims {
val exprs: Map[Class[_ <: Expression], ExprRule[_ <: Expression]] = {
Seq(
GpuOverrides.expr[BloomFilterMightContain](
"Bloom filter query",
ExprChecks.binaryProject(
TypeSig.BOOLEAN,
TypeSig.BOOLEAN,
("lhs", TypeSig.BINARY + TypeSig.NULL, TypeSig.BINARY + TypeSig.NULL),
("rhs", TypeSig.LONG + TypeSig.NULL, TypeSig.LONG + TypeSig.NULL)),
(a, conf, p, r) => new BinaryExprMeta[BloomFilterMightContain](a, conf, p, r) {
override def convertToGpu(lhs: Expression, rhs: Expression): GpuExpression =
GpuBloomFilterMightContain(lhs, rhs)
})
).map(r => (r.getClassFor.asSubclass(classOf[Expression]), r)).toMap
}
}