From 86f4057762499c2d234599b6726b75b1d3c00067 Mon Sep 17 00:00:00 2001 From: Chenhao Li Date: Sun, 7 Apr 2024 11:09:49 -0700 Subject: [PATCH] cast --- .../apache/spark/types/variant/Variant.java | 6 +- .../spark/types/variant/VariantBuilder.java | 397 ++++++++++++------ .../spark/sql/catalyst/expressions/Cast.scala | 7 + .../variant/VariantExpressionEvalUtils.scala | 70 ++- .../variant/VariantExpressionSuite.scala | 55 +++ 5 files changed, 396 insertions(+), 139 deletions(-) diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java index 4aeb2c6e14355..a705daaf323b2 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/Variant.java @@ -41,12 +41,12 @@ * define a new class to avoid depending on or modifying Spark. */ public final class Variant { - private final byte[] value; - private final byte[] metadata; + final byte[] value; + final byte[] metadata; // The variant value doesn't use the whole `value` binary, but starts from its `pos` index and // spans a size of `valueSize(value, pos)`. This design avoids frequent copies of the value binary // when reading a sub-variant in the array/object element. - private final int pos; + final int pos; public Variant(byte[] value, byte[] metadata) { this(value, metadata, 0); diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java index 21a12cbe9d714..fbb29ee694675 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantBuilder.java @@ -31,6 +31,9 @@ import com.fasterxml.jackson.core.JsonParseException; import com.fasterxml.jackson.core.JsonToken; import com.fasterxml.jackson.core.exc.InputCoercionException; +import org.apache.spark.QueryContext; +import org.apache.spark.SparkRuntimeException; +import scala.collection.immutable.Map$; import static org.apache.spark.types.variant.VariantUtil.*; @@ -61,7 +64,7 @@ public static Variant parseJson(JsonParser parser) throws IOException { } // Build the variant metadata from `dictionaryKeys` and return the variant result. - private Variant result() { + public Variant result() { int numKeys = dictionaryKeys.size(); // Use long to avoid overflow in accumulating lengths. long dictionaryStringSize = 0; @@ -100,6 +103,248 @@ private Variant result() { return new Variant(Arrays.copyOfRange(writeBuffer, 0, writePos), metadata); } + public int getWritePos() { + return writePos; + } + + public void appendString(String str) { + byte[] text = str.getBytes(StandardCharsets.UTF_8); + boolean longStr = text.length > MAX_SHORT_STR_SIZE; + checkCapacity((longStr ? 1 + U32_SIZE : 1) + text.length); + if (longStr) { + writeBuffer[writePos++] = primitiveHeader(LONG_STR); + writeLong(writeBuffer, writePos, text.length, U32_SIZE); + writePos += U32_SIZE; + } else { + writeBuffer[writePos++] = shortStrHeader(text.length); + } + System.arraycopy(text, 0, writeBuffer, writePos, text.length); + writePos += text.length; + } + + public void appendNull() { + checkCapacity(1); + writeBuffer[writePos++] = primitiveHeader(NULL); + } + + public void appendBoolean(boolean b) { + checkCapacity(1); + writeBuffer[writePos++] = primitiveHeader(b ? 
TRUE : FALSE); + } + + public void appendLong(long l) { + checkCapacity(1 + 8); + if (l == (byte) l) { + writeBuffer[writePos++] = primitiveHeader(INT1); + writeLong(writeBuffer, writePos, l, 1); + writePos += 1; + } else if (l == (short) l) { + writeBuffer[writePos++] = primitiveHeader(INT2); + writeLong(writeBuffer, writePos, l, 2); + writePos += 2; + } else if (l == (int) l) { + writeBuffer[writePos++] = primitiveHeader(INT4); + writeLong(writeBuffer, writePos, l, 4); + writePos += 4; + } else { + writeBuffer[writePos++] = primitiveHeader(INT8); + writeLong(writeBuffer, writePos, l, 8); + writePos += 8; + } + } + + public void appendDouble(double d) { + checkCapacity(1 + 8); + writeBuffer[writePos++] = primitiveHeader(DOUBLE); + writeLong(writeBuffer, writePos, Double.doubleToLongBits(d), 8); + writePos += 8; + } + + public void appendDecimal(BigDecimal d) { + checkCapacity(2 + 16); + BigInteger unscaled = d.unscaledValue(); + if (d.scale() <= MAX_DECIMAL4_PRECISION && d.precision() <= MAX_DECIMAL4_PRECISION) { + writeBuffer[writePos++] = primitiveHeader(DECIMAL4); + writeBuffer[writePos++] = (byte) d.scale(); + writeLong(writeBuffer, writePos, unscaled.intValueExact(), 4); + writePos += 4; + } else if (d.scale() <= MAX_DECIMAL8_PRECISION && d.precision() <= MAX_DECIMAL8_PRECISION) { + writeBuffer[writePos++] = primitiveHeader(DECIMAL8); + writeBuffer[writePos++] = (byte) d.scale(); + writeLong(writeBuffer, writePos, unscaled.longValueExact(), 8); + writePos += 8; + } else { + assert d.scale() <= MAX_DECIMAL16_PRECISION && d.precision() <= MAX_DECIMAL16_PRECISION; + writeBuffer[writePos++] = primitiveHeader(DECIMAL16); + writeBuffer[writePos++] = (byte) d.scale(); + // `toByteArray` returns a big-endian representation. We need to copy it reversely and sign + // extend it to 16 bytes. + byte[] bytes = unscaled.toByteArray(); + for (int i = 0; i < bytes.length; ++i) { + writeBuffer[writePos + i] = bytes[bytes.length - 1 - i]; + } + byte sign = (byte) (bytes[0] < 0 ? 
-1 : 0);
+      for (int i = bytes.length; i < 16; ++i) {
+        writeBuffer[writePos + i] = sign;
+      }
+      writePos += 16;
+    }
+  }
+
+  public void appendDate(int daysSinceEpoch) {
+    checkCapacity(1 + 4);
+    writeBuffer[writePos++] = primitiveHeader(DATE);
+    writeLong(writeBuffer, writePos, daysSinceEpoch, 4);
+    writePos += 4;
+  }
+
+  public void appendTimestamp(long microsSinceEpoch) {
+    checkCapacity(1 + 8);
+    writeBuffer[writePos++] = primitiveHeader(TIMESTAMP);
+    writeLong(writeBuffer, writePos, microsSinceEpoch, 8);
+    writePos += 8;
+  }
+
+  public void appendTimestampNtz(long microsSinceEpoch) {
+    checkCapacity(1 + 8);
+    writeBuffer[writePos++] = primitiveHeader(TIMESTAMP_NTZ);
+    writeLong(writeBuffer, writePos, microsSinceEpoch, 8);
+    writePos += 8;
+  }
+
+  public void appendFloat(float f) {
+    checkCapacity(1 + 4);
+    writeBuffer[writePos++] = primitiveHeader(FLOAT);
+    writeLong(writeBuffer, writePos, Float.floatToIntBits(f), 4);
+    writePos += 4;
+  }
+
+  public void appendBinary(byte[] binary) {
+    checkCapacity(1 + U32_SIZE + binary.length);
+    writeBuffer[writePos++] = primitiveHeader(BINARY);
+    writeLong(writeBuffer, writePos, binary.length, U32_SIZE);
+    writePos += U32_SIZE;
+    System.arraycopy(binary, 0, writeBuffer, writePos, binary.length);
+    writePos += binary.length;
+  }
+
+  public int getOrInsertKey(String key) {
+    int id;
+    if (dictionary.containsKey(key)) {
+      id = dictionary.get(key);
+    } else {
+      id = dictionaryKeys.size();
+      dictionary.put(key, id);
+      dictionaryKeys.add(key.getBytes(StandardCharsets.UTF_8));
+    }
+    return id;
+  }
+
+  public void finishWritingObject(int start, ArrayList<FieldEntry> fields) {
+    int dataSize = writePos - start;
+    int size = fields.size();
+    Collections.sort(fields);
+    int maxId = size == 0 ? 0 : fields.get(0).id;
+    // Check for duplicate field keys. Only need to check adjacent keys because they are sorted.
+    for (int i = 1; i < size; ++i) {
+      maxId = Math.max(maxId, fields.get(i).id);
+      String key = fields.get(i).key;
+      if (key.equals(fields.get(i - 1).key)) {
+        throw new SparkRuntimeException("VARIANT_DUPLICATE_KEY",
+          Map$.MODULE$.empty(), null, new QueryContext[]{}, "");
+      }
+    }
+    boolean largeSize = size > U8_MAX;
+    int sizeBytes = largeSize ? U32_SIZE : 1;
+    int idSize = getIntegerSize(maxId);
+    int offsetSize = getIntegerSize(dataSize);
+    // The space for header byte, object size, id list, and offset list.
+    int headerSize = 1 + sizeBytes + size * idSize + (size + 1) * offsetSize;
+    checkCapacity(headerSize);
+    // Shift the just-written field data to make room for the object header section.
+    System.arraycopy(writeBuffer, start, writeBuffer, start + headerSize, dataSize);
+    writePos += headerSize;
+    writeBuffer[start] = objectHeader(largeSize, idSize, offsetSize);
+    writeLong(writeBuffer, start + 1, size, sizeBytes);
+    int idStart = start + 1 + sizeBytes;
+    int offsetStart = idStart + size * idSize;
+    for (int i = 0; i < size; ++i) {
+      writeLong(writeBuffer, idStart + i * idSize, fields.get(i).id, idSize);
+      writeLong(writeBuffer, offsetStart + i * offsetSize, fields.get(i).offset, offsetSize);
+    }
+    writeLong(writeBuffer, offsetStart + size * offsetSize, dataSize, offsetSize);
+  }
+
+  public void finishWritingArray(int start, ArrayList<Integer> offsets) {
+    int dataSize = writePos - start;
+    int size = offsets.size();
+    boolean largeSize = size > U8_MAX;
+    int sizeBytes = largeSize ? U32_SIZE : 1;
+    int offsetSize = getIntegerSize(dataSize);
+    // The space for header byte, array size, and offset list.
+    int headerSize = 1 + sizeBytes + (size + 1) * offsetSize;
+    checkCapacity(headerSize);
+    // Shift the just-written field data to make room for the header section.
+    System.arraycopy(writeBuffer, start, writeBuffer, start + headerSize, dataSize);
+    writePos += headerSize;
+    writeBuffer[start] = arrayHeader(largeSize, offsetSize);
+    writeLong(writeBuffer, start + 1, size, sizeBytes);
+    int offsetStart = start + 1 + sizeBytes;
+    for (int i = 0; i < size; ++i) {
+      writeLong(writeBuffer, offsetStart + i * offsetSize, offsets.get(i), offsetSize);
+    }
+    writeLong(writeBuffer, offsetStart + size * offsetSize, dataSize, offsetSize);
+  }
+
+  public void appendVariant(Variant v) {
+    appendVariantImpl(v.value, v.metadata, v.pos);
+  }
+
+  private void appendVariantImpl(byte[] value, byte[] metadata, int pos) {
+    checkIndex(pos, value.length);
+    int basicType = value[pos] & BASIC_TYPE_MASK;
+    switch (basicType) {
+      case OBJECT:
+        handleObject(value, pos, (size, idSize, offsetSize, idStart, offsetStart, dataStart) -> {
+          ArrayList<FieldEntry> fields = new ArrayList<>(size);
+          int start = writePos;
+          for (int i = 0; i < size; ++i) {
+            int id = readUnsigned(value, idStart + idSize * i, idSize);
+            int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize);
+            int elementPos = dataStart + offset;
+            String key = getMetadataKey(metadata, id);
+            int newId = getOrInsertKey(key);
+            fields.add(new FieldEntry(key, newId, writePos - start));
+            appendVariantImpl(value, metadata, elementPos);
+          }
+          finishWritingObject(start, fields);
+          return null;
+        });
+        break;
+      case ARRAY:
+        handleArray(value, pos, (size, offsetSize, offsetStart, dataStart) -> {
+          ArrayList<Integer> offsets = new ArrayList<>(size);
+          int start = writePos;
+          for (int i = 0; i < size; ++i) {
+            int offset = readUnsigned(value, offsetStart + offsetSize * i, offsetSize);
+            int elementPos = dataStart + offset;
+            offsets.add(writePos - start);
+            appendVariantImpl(value, metadata, elementPos);
+          }
+          finishWritingArray(start, offsets);
+          return null;
+        });
+        break;
+      default:
+        int size = valueSize(value, pos);
+        checkIndex(pos + size - 1, value.length);
+        checkCapacity(size);
+        System.arraycopy(value, pos, writeBuffer, writePos, size);
+        writePos += size;
+        break;
+    }
+  }
+
   private void checkCapacity(int additional) {
     int required = writePos + additional;
     if (required > writeBuffer.length) {
@@ -117,12 +362,12 @@ private void checkCapacity(int additional) {
 
   // Temporarily store the information of a field. We need to collect all fields in an JSON object,
   // sort them by their keys, and build the variant object in sorted order.
-  private static final class FieldEntry implements Comparable<FieldEntry> {
+  public static final class FieldEntry implements Comparable<FieldEntry> {
     final String key;
     final int id;
     final int offset;
 
-    FieldEntry(String key, int id, int offset) {
+    public FieldEntry(String key, int id, int offset) {
       this.key = key;
       this.id = id;
       this.offset = offset;
@@ -143,117 +388,32 @@ private void buildJson(JsonParser parser) throws IOException {
       case START_OBJECT: {
         ArrayList<FieldEntry> fields = new ArrayList<>();
         int start = writePos;
-        int maxId = 0;
         while (parser.nextToken() != JsonToken.END_OBJECT) {
           String key = parser.currentName();
           parser.nextToken();
-          int id;
-          if (dictionary.containsKey(key)) {
-            id = dictionary.get(key);
-          } else {
-            id = dictionaryKeys.size();
-            dictionary.put(key, id);
-            dictionaryKeys.add(key.getBytes(StandardCharsets.UTF_8));
-          }
-          maxId = Math.max(maxId, id);
-          int offset = writePos - start;
-          fields.add(new FieldEntry(key, id, offset));
+          int id = getOrInsertKey(key);
+          fields.add(new FieldEntry(key, id, writePos - start));
           buildJson(parser);
         }
-        int dataSize = writePos - start;
-        int size = fields.size();
-        Collections.sort(fields);
-        // Check for duplicate field keys. Only need to check adjacent key because they are sorted.
-        for (int i = 1; i < size; ++i) {
-          String key = fields.get(i - 1).key;
-          if (key.equals(fields.get(i).key)) {
-            throw new JsonParseException(parser, "Duplicate key: " + key);
-          }
-        }
-        boolean largeSize = size > U8_MAX;
-        int sizeBytes = largeSize ? U32_SIZE : 1;
-        int idSize = getIntegerSize(maxId);
-        int offsetSize = getIntegerSize(dataSize);
-        // The space for header byte, object size, id list, and offset list.
-        int headerSize = 1 + sizeBytes + size * idSize + (size + 1) * offsetSize;
-        checkCapacity(headerSize);
-        // Shift the just-written field data to make room for the object header section.
-        System.arraycopy(writeBuffer, start, writeBuffer, start + headerSize, dataSize);
-        writePos += headerSize;
-        writeBuffer[start] = objectHeader(largeSize, idSize, offsetSize);
-        writeLong(writeBuffer, start + 1, size, sizeBytes);
-        int idStart = start + 1 + sizeBytes;
-        int offsetStart = idStart + size * idSize;
-        for (int i = 0; i < size; ++i) {
-          writeLong(writeBuffer, idStart + i * idSize, fields.get(i).id, idSize);
-          writeLong(writeBuffer, offsetStart + i * offsetSize, fields.get(i).offset, offsetSize);
-        }
-        writeLong(writeBuffer, offsetStart + size * offsetSize, dataSize, offsetSize);
+        finishWritingObject(start, fields);
         break;
       }
       case START_ARRAY: {
         ArrayList<Integer> offsets = new ArrayList<>();
         int start = writePos;
         while (parser.nextToken() != JsonToken.END_ARRAY) {
-          int offset = writePos - start;
-          offsets.add(offset);
+          offsets.add(writePos - start);
           buildJson(parser);
         }
-        int dataSize = writePos - start;
-        int size = offsets.size();
-        boolean largeSize = size > U8_MAX;
-        int sizeBytes = largeSize ? U32_SIZE : 1;
-        int offsetSize = getIntegerSize(dataSize);
-        // The space for header byte, object size, and offset list.
-        int headerSize = 1 + sizeBytes + (size + 1) * offsetSize;
-        checkCapacity(headerSize);
-        // Shift the just-written field data to make room for the header section.
- System.arraycopy(writeBuffer, start, writeBuffer, start + headerSize, dataSize); - writePos += headerSize; - writeBuffer[start] = arrayHeader(largeSize, offsetSize); - writeLong(writeBuffer, start + 1, size, sizeBytes); - int offsetStart = start + 1 + sizeBytes; - for (int i = 0; i < size; ++i) { - writeLong(writeBuffer, offsetStart + i * offsetSize, offsets.get(i), offsetSize); - } - writeLong(writeBuffer, offsetStart + size * offsetSize, dataSize, offsetSize); + finishWritingArray(start, offsets); break; } case VALUE_STRING: - byte[] text = parser.getText().getBytes(StandardCharsets.UTF_8); - boolean longStr = text.length > MAX_SHORT_STR_SIZE; - checkCapacity((longStr ? 1 + U32_SIZE : 1) + text.length); - if (longStr) { - writeBuffer[writePos++] = primitiveHeader(LONG_STR); - writeLong(writeBuffer, writePos, text.length, U32_SIZE); - writePos += U32_SIZE; - } else { - writeBuffer[writePos++] = shortStrHeader(text.length); - } - System.arraycopy(text, 0, writeBuffer, writePos, text.length); - writePos += text.length; + appendString(parser.getText()); break; case VALUE_NUMBER_INT: try { - long l = parser.getLongValue(); - checkCapacity(1 + 8); - if (l == (byte) l) { - writeBuffer[writePos++] = primitiveHeader(INT1); - writeLong(writeBuffer, writePos, l, 1); - writePos += 1; - } else if (l == (short) l) { - writeBuffer[writePos++] = primitiveHeader(INT2); - writeLong(writeBuffer, writePos, l, 2); - writePos += 2; - } else if (l == (int) l) { - writeBuffer[writePos++] = primitiveHeader(INT4); - writeLong(writeBuffer, writePos, l, 4); - writePos += 4; - } else { - writeBuffer[writePos++] = primitiveHeader(INT8); - writeLong(writeBuffer, writePos, l, 8); - writePos += 8; - } + appendLong(parser.getLongValue()); } catch (InputCoercionException ignored) { // If the value doesn't fit any integer type, parse it as decimal or floating instead. 
parseFloatingPoint(parser); @@ -263,16 +423,13 @@ private void buildJson(JsonParser parser) throws IOException { parseFloatingPoint(parser); break; case VALUE_TRUE: - checkCapacity(1); - writeBuffer[writePos++] = primitiveHeader(TRUE); + appendBoolean(true); break; case VALUE_FALSE: - checkCapacity(1); - writeBuffer[writePos++] = primitiveHeader(FALSE); + appendBoolean(false); break; case VALUE_NULL: - checkCapacity(1); - writeBuffer[writePos++] = primitiveHeader(NULL); + appendNull(); break; default: throw new JsonParseException(parser, "Unexpected token " + token); @@ -290,10 +447,7 @@ private int getIntegerSize(int value) { private void parseFloatingPoint(JsonParser parser) throws IOException { if (!tryParseDecimal(parser.getText())) { - checkCapacity(1 + 8); - writeBuffer[writePos++] = primitiveHeader(DOUBLE); - writeLong(writeBuffer, writePos, Double.doubleToLongBits(parser.getDoubleValue()), 8); - writePos += 8; + appendDouble(parser.getDoubleValue()); } } @@ -308,36 +462,11 @@ private boolean tryParseDecimal(String input) { } } BigDecimal d = new BigDecimal(input); - checkCapacity(2 + 16); - BigInteger unscaled = d.unscaledValue(); - if (d.scale() <= MAX_DECIMAL4_PRECISION && d.precision() <= MAX_DECIMAL4_PRECISION) { - writeBuffer[writePos++] = primitiveHeader(DECIMAL4); - writeBuffer[writePos++] = (byte)d.scale(); - writeLong(writeBuffer, writePos, unscaled.intValueExact(), 4); - writePos += 4; - } else if (d.scale() <= MAX_DECIMAL8_PRECISION && d.precision() <= MAX_DECIMAL8_PRECISION) { - writeBuffer[writePos++] = primitiveHeader(DECIMAL8); - writeBuffer[writePos++] = (byte)d.scale(); - writeLong(writeBuffer, writePos, unscaled.longValueExact(), 8); - writePos += 8; - } else if (d.scale() <= MAX_DECIMAL16_PRECISION && d.precision() <= MAX_DECIMAL16_PRECISION) { - writeBuffer[writePos++] = primitiveHeader(DECIMAL16); - writeBuffer[writePos++] = (byte)d.scale(); - // `toByteArray` returns a big-endian representation. We need to copy it reversely and sign - // extend it to 16 bytes. - byte[] bytes = unscaled.toByteArray(); - for (int i = 0; i < bytes.length; ++i) { - writeBuffer[writePos + i] = bytes[bytes.length - 1 - i]; - } - byte sign = (byte) (bytes[0] < 0 ? -1 : 0); - for (int i = bytes.length; i < 16; ++i) { - writeBuffer[writePos + i] = sign; - } - writePos += 16; - } else { - return false; + if (d.scale() <= MAX_DECIMAL16_PRECISION && d.precision() <= MAX_DECIMAL16_PRECISION) { + appendDecimal(d); + return true; } - return true; + return false; } // The write buffer in building the variant value. Its first `writePos` bytes has been written. 
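For reference, the builder surface made public above (getWritePos, the append* methods, getOrInsertKey, finishWritingObject/finishWritingArray, result) is intended to be driven as in the following minimal Scala sketch. It is illustrative only and not part of this patch; it mirrors the pattern VariantExpressionEvalUtils.buildVariant uses further below.

    import org.apache.spark.types.variant.VariantBuilder

    // Build the variant equivalent of the JSON array [0,1,2] through the public API.
    val builder = new VariantBuilder
    val start = builder.getWritePos                      // element offsets are relative to `start`
    val offsets = new java.util.ArrayList[java.lang.Integer]()
    for (i <- 0 until 3) {
      offsets.add(builder.getWritePos - start)           // record the offset before appending the element
      builder.appendLong(i)
    }
    builder.finishWritingArray(start, offsets)           // shifts the elements and prepends the array header
    val v = builder.result()                             // v.getValue / v.getMetadata back a VariantVal
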
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 94cf7130d4852..e252075c9c1c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -128,6 +128,7 @@ object Cast extends QueryErrorsBase { case (TimestampType, _: NumericType) => true case (VariantType, _) => variant.VariantGet.checkDataType(to) + case (_, VariantType) => variant.VariantGet.checkDataType(from) case (ArrayType(fromType, fn), ArrayType(toType, tn)) => canAnsiCast(fromType, toType) && resolvableNullability(fn, tn) @@ -236,6 +237,7 @@ object Cast extends QueryErrorsBase { case (_: NumericType, _: NumericType) => true case (VariantType, _) => variant.VariantGet.checkDataType(to) + case (_, VariantType) => variant.VariantGet.checkDataType(from) case (ArrayType(fromType, fn), ArrayType(toType, tn)) => canCast(fromType, toType) && @@ -1119,6 +1121,7 @@ case class Cast( } else { to match { case dt if dt == from => identity[Any] + case VariantType => input => variant.VariantExpressionEvalUtils.castToVariant(input, from) case _: StringType => castToString(from) case BinaryType => castToBinary(from) case DateType => castToDate(from) @@ -1223,6 +1226,10 @@ case class Cast( $evPrim = (${CodeGenerator.boxedType(to)})$tmp; } """ + case VariantType => + val cls = variant.VariantExpressionEvalUtils.getClass.getName.stripSuffix("$") + val fromArg = ctx.addReferenceObj("from", from) + (c, evPrim, evNull) => code"$evPrim = $cls.castToVariant($c, $fromArg);" case _: StringType => (c, evPrim, _) => castToStringCode(from, ctx).apply(c, evPrim) case BinaryType => castToBinaryCode(from) case DateType => castToDateCode(from, ctx) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala index 74fae91f98a6c..aefd4e33427f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionEvalUtils.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions.variant import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.util.BadRecordException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, BadRecordException, MapData} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.types.variant.{VariantBuilder, VariantSizeLimitException, VariantUtil} +import org.apache.spark.sql.types._ +import org.apache.spark.types.variant.{Variant, VariantBuilder, VariantSizeLimitException, VariantUtil} import org.apache.spark.unsafe.types.{UTF8String, VariantVal} /** @@ -41,4 +43,68 @@ object VariantExpressionEvalUtils { input.toString, BadRecordException(() => input, cause = e)) } } + + def castToVariant(input: Any, dataType: DataType): VariantVal = { + val builder = new VariantBuilder + buildVariant(builder, input, dataType) + val v = builder.result() + new VariantVal(v.getValue, v.getMetadata) + } + + private def buildVariant(builder: VariantBuilder, input: Any, dataType: DataType): Unit = { + if (input == null) { + builder.appendNull() + return + } + dataType match { + 
case BooleanType => builder.appendBoolean(input.asInstanceOf[Boolean]) + case ByteType => builder.appendLong(input.asInstanceOf[Byte]) + case ShortType => builder.appendLong(input.asInstanceOf[Short]) + case IntegerType => builder.appendLong(input.asInstanceOf[Int]) + case LongType => builder.appendLong(input.asInstanceOf[Long]) + case FloatType => builder.appendFloat(input.asInstanceOf[Float]) + case DoubleType => builder.appendDouble(input.asInstanceOf[Double]) + case StringType => builder.appendString(input.asInstanceOf[UTF8String].toString) + case BinaryType => builder.appendBinary(input.asInstanceOf[Array[Byte]]) + case DateType => builder.appendDate(input.asInstanceOf[Int]) + case TimestampType => builder.appendTimestamp(input.asInstanceOf[Long]) + case TimestampNTZType => builder.appendTimestampNtz(input.asInstanceOf[Long]) + case VariantType => + val v = input.asInstanceOf[VariantVal] + builder.appendVariant(new Variant(v.getValue, v.getMetadata)) + case ArrayType(elementType, _) => + val data = input.asInstanceOf[ArrayData] + val start = builder.getWritePos + val offsets = new java.util.ArrayList[java.lang.Integer](data.numElements()) + for (i <- 0 until data.numElements()) { + offsets.add(builder.getWritePos - start) + buildVariant(builder, data.get(i, elementType), elementType) + } + builder.finishWritingArray(start, offsets) + case MapType(StringType, valueType, _) => + val data = input.asInstanceOf[MapData] + val keys = data.keyArray() + val values = data.valueArray() + val start = builder.getWritePos + val fields = new java.util.ArrayList[VariantBuilder.FieldEntry](data.numElements()) + for (i <- 0 until data.numElements()) { + val key = keys.getUTF8String(i).toString + val id = builder.getOrInsertKey(key) + fields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos - start)) + buildVariant(builder, values.get(i, valueType), valueType) + } + builder.finishWritingObject(start, fields) + case StructType(structFields) => + val data = input.asInstanceOf[InternalRow] + val start = builder.getWritePos + val fields = new java.util.ArrayList[VariantBuilder.FieldEntry](structFields.length) + for (i <- 0 until structFields.length) { + val key = structFields(i).name + val id = builder.getOrInsertKey(key) + fields.add(new VariantBuilder.FieldEntry(key, id, builder.getWritePos - start)) + buildVariant(builder, data.get(i, structFields(i).dataType), structFields(i).dataType) + } + builder.finishWritingObject(start, fields) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala index 24675518646d0..dadc285a35f64 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/variant/VariantExpressionSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions.variant import java.time.{LocalDateTime, ZoneId, ZoneOffset} +import scala.reflect.runtime.universe.TypeTag + import org.apache.spark.{SparkFunSuite, SparkRuntimeException} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone @@ -797,4 +799,57 @@ class VariantExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkCast(Array(primitiveHeader(BINARY), 5, 0, 0, 0, 72, 101, 108, 108, 111), StringType, "Hello") } + + private def parseJson(input: 
String): VariantVal =
+    VariantExpressionEvalUtils.parseJson(UTF8String.fromString(input))
+
+  test("cast to variant") {
+    def check[T : TypeTag](input: T, expectedJson: String): Unit = {
+      val cast = Cast(Literal.create(input), VariantType, evalMode = EvalMode.ANSI)
+      checkEvaluation(StructsToJson(Map.empty, cast), expectedJson)
+    }
+
+    checkEvaluation(Cast(Literal(null, StringType), VariantType, evalMode = EvalMode.ANSI), null)
+    for (input <- Seq[Any](false, true, 0.toByte, 1.toShort, 2, 3L, 4.0F, 5.0D)) {
+      check(input, input.toString)
+    }
+    check(Array(null, "a", "b", "c"), """[null,"a","b","c"]""")
+    check(Map("z" -> 1, "y" -> 2, "x" -> 3), """{"x":3,"y":2,"z":1}""")
+    check(Array(parseJson("""{"a": 1,"b": [1, 2, 3]}"""),
+      parseJson("""{"c": true,"d": {"e": "str"}}""")),
+      """[{"a":1,"b":[1,2,3]},{"c":true,"d":{"e":"str"}}]""")
+  }
 }
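End to end, the cast path added in Cast.scala and VariantExpressionEvalUtils.scala can be exercised roughly as in the Scala sketch below. This is illustrative only and not code added by this patch; the suite's checkEvaluation calls above cover the same ground and additionally resolve time zones.

    import org.apache.spark.sql.catalyst.expressions.{Cast, EvalMode, Literal}
    import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
    import org.apache.spark.sql.types._
    import org.apache.spark.unsafe.types.UTF8String

    // Interpreted path: Cast(child, VariantType) now routes through castToVariant.
    val casted = Cast(Literal.create(Map("a" -> 1, "b" -> 2)), VariantType, evalMode = EvalMode.ANSI)
    val asVariant = casted.eval()   // a VariantVal encoding the object {"a":1,"b":2}

    // The helper can also be invoked directly on an already-converted Catalyst value.
    val v = VariantExpressionEvalUtils.castToVariant(UTF8String.fromString("hello"), StringType)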