[SPARK-50644][SQL] Read variant struct in Parquet reader
### What changes were proposed in this pull request?

It adds support for variant struct in the Parquet reader. The concept of a variant struct was introduced in apache#49235: it contains all the fields that the query requests to extract from a variant column.
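For illustration (the table and column names below are hypothetical, not part of this PR), a query that reads only a couple of paths from a variant column can be served from a variant struct containing just those extracted fields:

```scala
// Hypothetical table `events` with a variant column `v`; `spark` is a SparkSession.
// The query requests only `$.user_id` and `$.ts`, so the scan can produce a struct
// of these two extracted fields instead of rebuilding the full variant for each row.
spark.sql(
  """
    |SELECT
    |  variant_get(v, '$.user_id', 'string') AS user_id,
    |  variant_get(v, '$.ts', 'timestamp') AS ts
    |FROM events
    |""".stripMargin)
```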

### Why are the changes needed?

By producing the variant struct directly in the Parquet reader, we can avoid reading and rebuilding the full variant, which makes variant processing more efficient.
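As a rough intuition, here is a self-contained, REPL-style toy sketch (not Spark's actual classes and not the code added by this PR): with shredded storage, a requested field can be read straight from its typed column, while rebuilding the full variant has to visit every shredded part first.

```scala
// Toy model of a shredded variant row: some fields live in typed columns,
// everything else stays in a catch-all residual part.
final case class ToyShreddedRow(typed: Map[String, Any], residual: Map[String, Any])

// Old path (conceptually): reassemble the whole value, touching every field.
def rebuildFull(row: ToyShreddedRow): Map[String, Any] = row.typed ++ row.residual

// New path (conceptually): extract only the fields the query asked for.
def extractRequested(row: ToyShreddedRow, fields: Seq[String]): Seq[Option[Any]] =
  fields.map(f => row.typed.get(f).orElse(row.residual.get(f)))

val row = ToyShreddedRow(Map("user_id" -> "u1", "ts" -> 1735000000L), Map("tags" -> Seq("a", "b")))
rebuildFull(row)                 // touches user_id, ts and tags
extractRequested(row, Seq("ts")) // touches only ts
```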

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Unit test.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#49263 from chenhao-db/spark_variant_struct_reader.

Authored-by: Chenhao Li <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
chenhao-db authored and cloud-fan committed Dec 24, 2024
1 parent 202b42e commit 2c1c4d2
Showing 8 changed files with 820 additions and 40 deletions.
@@ -49,9 +49,8 @@ public static Variant rebuild(ShreddedRow row, VariantSchema schema) {
       throw malformedVariant();
     }
     byte[] metadata = row.getBinary(schema.topLevelMetadataIdx);
-    if (schema.variantIdx >= 0 && schema.typedIdx < 0) {
-      // The variant is unshredded. We are not required to do anything special, but we can have an
-      // optimization to avoid `rebuild`.
+    if (schema.isUnshredded()) {
+      // `rebuild` is unnecessary for unshredded variant.
       if (row.isNullAt(schema.variantIdx)) {
         throw malformedVariant();
       }
@@ -65,8 +64,8 @@ public static Variant rebuild(ShreddedRow row, VariantSchema schema) {
   // Rebuild a variant value from the shredded data according to the reconstruction algorithm in
   // https://github.com/apache/parquet-format/blob/master/VariantShredding.md.
   // Append the result to `builder`.
-  private static void rebuild(ShreddedRow row, byte[] metadata, VariantSchema schema,
-      VariantBuilder builder) {
+  public static void rebuild(ShreddedRow row, byte[] metadata, VariantSchema schema,
+      VariantBuilder builder) {
     int typedIdx = schema.typedIdx;
     int variantIdx = schema.variantIdx;
     if (typedIdx >= 0 && !row.isNullAt(typedIdx)) {
@@ -138,6 +138,12 @@ public VariantSchema(int typedIdx, int variantIdx, int topLevelMetadataIdx, int
     this.arraySchema = arraySchema;
   }
 
+  // Return whether the variant column is unshredded. The user is not required to do anything
+  // special, but can have certain optimizations for unshredded variant.
+  public boolean isUnshredded() {
+    return topLevelMetadataIdx >= 0 && variantIdx >= 0 && typedIdx < 0;
+  }
+
   @Override
   public String toString() {
     return "VariantSchema{" +
@@ -35,7 +35,6 @@
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.types.VariantType;
 import org.apache.spark.types.variant.VariantSchema;
-import org.apache.spark.unsafe.types.VariantVal;
 
 /**
  * Contains necessary information representing a Parquet column, either of primitive or nested type.
@@ -49,6 +48,9 @@ final class ParquetColumnVector {
   // contains only one child that reads the underlying file content. This `ParquetColumnVector`
   // should assemble Spark variant values from the file content.
   private VariantSchema variantSchema;
+  // Only meaningful if `variantSchema` is not null. See `SparkShreddingUtils.getFieldsToExtract`
+  // for its meaning.
+  private FieldToExtract[] fieldsToExtract;
 
   /**
    * Repetition & Definition levels
@@ -117,6 +119,7 @@ final class ParquetColumnVector {
         fileContent, capacity, memoryMode, missingColumns, false, null);
       children.add(contentVector);
       variantSchema = SparkShreddingUtils.buildVariantSchema(fileContentCol.sparkType());
+      fieldsToExtract = SparkShreddingUtils.getFieldsToExtract(column.sparkType(), variantSchema);
       repetitionLevels = contentVector.repetitionLevels;
       definitionLevels = contentVector.definitionLevels;
     } else if (isPrimitive) {
@@ -188,20 +191,11 @@ void assemble() {
     if (variantSchema != null) {
       children.get(0).assemble();
       WritableColumnVector fileContent = children.get(0).getValueVector();
-      int numRows = fileContent.getElementsAppended();
-      vector.reset();
-      vector.reserve(numRows);
-      WritableColumnVector valueChild = vector.getChild(0);
-      WritableColumnVector metadataChild = vector.getChild(1);
-      for (int i = 0; i < numRows; ++i) {
-        if (fileContent.isNullAt(i)) {
-          vector.appendStruct(true);
-        } else {
-          vector.appendStruct(false);
-          VariantVal v = SparkShreddingUtils.rebuild(fileContent.getStruct(i), variantSchema);
-          valueChild.appendByteArray(v.getValue(), 0, v.getValue().length);
-          metadataChild.appendByteArray(v.getMetadata(), 0, v.getMetadata().length);
-        }
+      if (fieldsToExtract == null) {
+        SparkShreddingUtils.assembleVariantBatch(fileContent, vector, variantSchema);
+      } else {
+        SparkShreddingUtils.assembleVariantStructBatch(fileContent, vector, variantSchema,
+          fieldsToExtract);
       }
       return;
     }
@@ -35,6 +35,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.VariantMetadata
 import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf}
 import org.apache.spark.sql.types._
 
@@ -221,6 +222,9 @@ object ParquetReadSupport extends Logging {
         clipParquetMapType(
           parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive, useFieldId)
 
+      case t: StructType if VariantMetadata.isVariantStruct(t) =>
+        clipVariantSchema(parquetType.asGroupType(), t)
+
       case t: StructType =>
         clipParquetGroup(parquetType.asGroupType(), t, caseSensitive, useFieldId)
 
@@ -390,6 +394,11 @@
       .named(parquetRecord.getName)
   }
 
+  private def clipVariantSchema(parquetType: GroupType, variantStruct: StructType): GroupType = {
+    // TODO(SHREDDING): clip `parquetType` to retain the necessary columns.
+    parquetType
+  }
+
   /**
    * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]].
    *
@@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.execution.datasources.DataSourceUtils
+import org.apache.spark.sql.execution.datasources.{DataSourceUtils, VariantMetadata}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
@@ -498,6 +498,9 @@ private[parquet] class ParquetRowConverter(
       case t: MapType =>
         new ParquetMapConverter(parquetType.asGroupType(), t, updater)
 
+      case t: StructType if VariantMetadata.isVariantStruct(t) =>
+        new ParquetVariantConverter(t, parquetType.asGroupType(), updater)
+
       case t: StructType =>
         val wrappedUpdater = {
           // SPARK-30338: avoid unnecessary InternalRow copying for nested structs:
@@ -536,12 +539,7 @@ private[parquet] class ParquetRowConverter(
 
       case t: VariantType =>
         if (SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED)) {
-          // Infer a Spark type from `parquetType`. This piece of code is copied from
-          // `ParquetArrayConverter`.
-          val messageType = Types.buildMessage().addField(parquetType).named("foo")
-          val column = new ColumnIOFactory().getColumnIO(messageType)
-          val parquetSparkType = schemaConverter.convertField(column.getChild(0)).sparkType
-          new ParquetVariantConverter(parquetType.asGroupType(), parquetSparkType, updater)
+          new ParquetVariantConverter(t, parquetType.asGroupType(), updater)
         } else {
           new ParquetUnshreddedVariantConverter(parquetType.asGroupType(), updater)
         }
@@ -909,13 +907,14 @@ private[parquet] class ParquetRowConverter(
 
   /** Parquet converter for Variant (shredded or unshredded) */
   private final class ParquetVariantConverter(
-      parquetType: GroupType,
-      parquetSparkType: DataType,
-      updater: ParentContainerUpdater)
+      targetType: DataType, parquetType: GroupType, updater: ParentContainerUpdater)
     extends ParquetGroupConverter(updater) {
 
     private[this] var currentRow: Any = _
+    private[this] val parquetSparkType = SparkShreddingUtils.parquetTypeToSparkType(parquetType)
     private[this] val variantSchema = SparkShreddingUtils.buildVariantSchema(parquetSparkType)
+    private[this] val fieldsToExtract =
+      SparkShreddingUtils.getFieldsToExtract(targetType, variantSchema)
     // A struct converter that reads the underlying file data.
     private[this] val fileConverter = new ParquetRowConverter(
       schemaConverter,
@@ -932,7 +931,12 @@ private[parquet] class ParquetRowConverter(
 
     override def end(): Unit = {
       fileConverter.end()
-      val v = SparkShreddingUtils.rebuild(currentRow.asInstanceOf[InternalRow], variantSchema)
+      val row = currentRow.asInstanceOf[InternalRow]
+      val v = if (fieldsToExtract == null) {
+        SparkShreddingUtils.assembleVariant(row, variantSchema)
+      } else {
+        SparkShreddingUtils.assembleVariantStruct(row, variantSchema, fieldsToExtract)
+      }
       updater.set(v)
     }
 
@@ -28,6 +28,7 @@ import org.apache.parquet.schema.Type.Repetition._
 
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.execution.datasources.VariantMetadata
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -185,6 +186,9 @@ class ParquetToSparkSchemaConverter(
         } else {
           convertVariantField(groupColumn)
         }
+      case groupColumn: GroupColumnIO if targetType.exists(VariantMetadata.isVariantStruct) =>
+        val col = convertGroupField(groupColumn)
+        col.copy(sparkType = targetType.get, variantFileType = Some(col))
       case groupColumn: GroupColumnIO => convertGroupField(groupColumn, targetType)
     }
   }