
Commit

initial
chenhao-db committed Aug 28, 2024
1 parent a3cb064 commit d46339b
Showing 7 changed files with 103 additions and 29 deletions.
@@ -26,10 +26,7 @@
import java.math.BigDecimal;
import java.math.BigInteger;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.*;

import com.fasterxml.jackson.core.JsonFactory;
import com.fasterxml.jackson.core.JsonParser;
@@ -43,24 +40,29 @@
* Build variant value and metadata by parsing JSON values.
*/
public class VariantBuilder {
public VariantBuilder(boolean allowDuplicateKeys) {
this.allowDuplicateKeys = allowDuplicateKeys;
}

/**
* Parse a JSON string as a Variant value.
* @throws VariantSizeLimitException if the resulting variant value or metadata would exceed
* the SIZE_LIMIT (currently 16 MiB).
* @throws IOException if any JSON parsing error happens.
*/
public static Variant parseJson(String json) throws IOException {
public static Variant parseJson(String json, boolean allowDuplicateKeys) throws IOException {
try (JsonParser parser = new JsonFactory().createParser(json)) {
parser.nextToken();
return parseJson(parser);
return parseJson(parser, allowDuplicateKeys);
}
}

/**
* Similar {@link #parseJson(String)}, but takes a JSON parser instead of string input.
* Similar to {@link #parseJson(String, boolean)}, but takes a JSON parser instead of a string input.
*/
public static Variant parseJson(JsonParser parser) throws IOException {
VariantBuilder builder = new VariantBuilder();
public static Variant parseJson(JsonParser parser, boolean allowDuplicateKeys)
throws IOException {
VariantBuilder builder = new VariantBuilder(allowDuplicateKeys);
builder.buildJson(parser);
return builder.result();
}
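
For reference, a minimal sketch of calling the updated entry point (the package name is an assumption; the signatures and the VARIANT_DUPLICATE_KEY error class come from this diff):

import org.apache.spark.types.variant.VariantBuilder  // assumed location of this class

// allowDuplicateKeys = true: the last value of a repeated key wins.
val lastWins = VariantBuilder.parseJson("""{"a": 1, "a": 2}""", true)
// allowDuplicateKeys = false: a duplicate key makes parsing throw (VARIANT_DUPLICATE_KEY).
val strict = VariantBuilder.parseJson("""{"a": 1, "a": 2}""", false)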
@@ -274,23 +276,61 @@ public int getWritePos() {
// record the offset of the field. The offset is computed as `getWritePos() - start`.
// 3. The caller calls `finishWritingObject` to finish writing a variant object.
//
// This function is responsible to sort the fields by key and check for any duplicate field keys.
// This function is responsible for sorting the fields by key. If there are duplicate field keys,
// the field with the greatest offset value (the last appended one) is kept.
public void finishWritingObject(int start, ArrayList<FieldEntry> fields) {
int dataSize = writePos - start;
int size = fields.size();
Collections.sort(fields);
int maxId = size == 0 ? 0 : fields.get(0).id;
// Check for duplicate field keys. Only need to check adjacent key because they are sorted.
for (int i = 1; i < size; ++i) {
maxId = Math.max(maxId, fields.get(i).id);
String key = fields.get(i).key;
if (key.equals(fields.get(i - 1).key)) {
@SuppressWarnings("unchecked")
Map<String, String> parameters = Map$.MODULE$.<String, String>empty().updated("key", key);
throw new SparkRuntimeException("VARIANT_DUPLICATE_KEY", parameters,
null, new QueryContext[]{}, "");
if (allowDuplicateKeys) {
int distinctPos = 0;
// Maintain a list of distinct keys in-place.
for (int i = 1; i < size; ++i) {
maxId = Math.max(maxId, fields.get(i).id);
if (fields.get(i).id == fields.get(i - 1).id) {
// Found a duplicate key. Keep the field with a greater offset.
if (fields.get(distinctPos).offset < fields.get(i).offset) {
fields.set(distinctPos, fields.get(distinctPos).withNewOffset(fields.get(i).offset));
}
} else {
// Found a distinct key. Add the field to the list.
++distinctPos;
fields.set(distinctPos, fields.get(i));
}
}
if (distinctPos + 1 < fields.size()) {
size = distinctPos + 1;
// Resize `fields` to `size`.
fields.subList(size, fields.size()).clear();
// Sort the fields by offsets so that we can move the value data of each field to the new
// offset without overwriting the fields after it.
fields.sort(Comparator.comparingInt(f -> f.offset));
int currentOffset = 0;
for (int i = 0; i < size; ++i) {
int oldOffset = fields.get(i).offset;
int fieldSize = VariantUtil.valueSize(writeBuffer, start + oldOffset);
System.arraycopy(writeBuffer, start + oldOffset,
writeBuffer, start + currentOffset, fieldSize);
fields.set(i, fields.get(i).withNewOffset(currentOffset));
currentOffset += fieldSize;
}
writePos = start + currentOffset;
// Restore the sort order by field key, as required by the variant spec.
Collections.sort(fields);
}
} else {
for (int i = 1; i < size; ++i) {
maxId = Math.max(maxId, fields.get(i).id);
String key = fields.get(i).key;
if (key.equals(fields.get(i - 1).key)) {
@SuppressWarnings("unchecked")
Map<String, String> parameters = Map$.MODULE$.<String, String>empty().updated("key", key);
throw new SparkRuntimeException("VARIANT_DUPLICATE_KEY", parameters,
null, new QueryContext[]{}, "");
}
}
}
int dataSize = writePos - start;
boolean largeSize = size > U8_MAX;
int sizeBytes = largeSize ? U32_SIZE : 1;
int idSize = getIntegerSize(maxId);
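
To make the last-wins compaction concrete, here is a standalone Scala sketch over (key, id, offset) triples; it models the rule described above rather than the builder's actual buffer moves:

case class Entry(key: String, id: Int, offset: Int)

// Fields in append order: "a" was written twice, the second time at offset 4.
val appended = List(Entry("a", 0, 0), Entry("b", 1, 2), Entry("a", 0, 4))

// For each id keep the entry with the greatest offset, then restore key order.
val deduped = appended.groupBy(_.id).values.map(_.maxBy(_.offset)).toList.sortBy(_.key)
// deduped == List(Entry("a", 0, 4), Entry("b", 1, 2))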
@@ -415,6 +455,10 @@ public FieldEntry(String key, int id, int offset) {
this.offset = offset;
}

FieldEntry withNewOffset(int newOffset) {
return new FieldEntry(key, id, newOffset);
}

@Override
public int compareTo(FieldEntry other) {
return key.compareTo(other.key);
@@ -518,4 +562,5 @@ private boolean tryParseDecimal(String input) {
private final HashMap<String, Integer> dictionary = new HashMap<>();
// Store all keys in `dictionary` in the order of id.
private final ArrayList<byte[]> dictionaryKeys = new ArrayList<>();
private final boolean allowDuplicateKeys;
}
@@ -31,7 +31,10 @@ import org.apache.spark.unsafe.types.{UTF8String, VariantVal}
*/
object VariantExpressionEvalUtils {

def parseJson(input: UTF8String, failOnError: Boolean = true): VariantVal = {
def parseJson(
input: UTF8String,
allowDuplicateKeys: Boolean = true,
failOnError: Boolean = true): VariantVal = {
def parseJsonFailure(exception: Throwable): VariantVal = {
if (failOnError) {
throw exception
@@ -40,7 +43,7 @@
}
}
try {
val v = VariantBuilder.parseJson(input.toString)
val v = VariantBuilder.parseJson(input.toString, allowDuplicateKeys)
new VariantVal(v.getValue, v.getMetadata)
} catch {
case _: VariantSizeLimitException =>
@@ -69,7 +72,8 @@

/** Cast a Spark value from `dataType` into the variant type. */
def castToVariant(input: Any, dataType: DataType): VariantVal = {
val builder = new VariantBuilder
// illegal.
val builder = new VariantBuilder(false)
buildVariant(builder, input, dataType)
val v = builder.result()
new VariantVal(v.getValue, v.getMetadata)
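
A hedged sketch of how the two boolean flags on parseJson combine (the package of VariantExpressionEvalUtils is assumed; UTF8String comes from the import shown above):

import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils
import org.apache.spark.unsafe.types.UTF8String

// allowDuplicateKeys controls duplicate-key handling; failOnError = false is the
// try_parse_json path and returns null instead of throwing on malformed input.
val v = VariantExpressionEvalUtils.parseJson(
  UTF8String.fromString("""{"a": 1, "a": 2}"""),
  allowDuplicateKeys = true,
  failOnError = false)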
@@ -59,8 +59,11 @@ case class ParseJson(child: Expression, failOnError: Boolean = true)
VariantExpressionEvalUtils.getClass,
VariantType,
"parseJson",
Seq(child, Literal(failOnError, BooleanType)),
inputTypes :+ BooleanType,
Seq(
child,
Literal(SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS), BooleanType),
Literal(failOnError, BooleanType)),
inputTypes :+ BooleanType :+ BooleanType,
returnNullable = !failOnError)

override def inputTypes: Seq[AbstractDataType] = StringTypeAnyCollation :: Nil
@@ -324,7 +327,7 @@ case object VariantGet {

if (dataType == VariantType) {
// Build a new variant, in order to strip off any unnecessary metadata.
val builder = new VariantBuilder
val builder = new VariantBuilder(false)
builder.appendVariant(v)
val result = builder.result()
return new VariantVal(result.getValue, result.getMetadata)
@@ -117,6 +117,8 @@ class JacksonParser(
}
}

private val variantAllowDuplicateKeys = SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_DUPLICATE_KEYS)

protected final def parseVariant(parser: JsonParser): VariantVal = {
// Skips `FIELD_NAME` at the beginning. This check is adapted from `parseJsonToken`, but we
// cannot directly use the function here because it also handles the `VALUE_NULL` token and
@@ -125,7 +127,7 @@
parser.nextToken()
}
try {
val v = VariantBuilder.parseJson(parser)
val v = VariantBuilder.parseJson(parser, variantAllowDuplicateKeys)
new VariantVal(v.getValue, v.getMetadata)
} catch {
case _: VariantSizeLimitException =>
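
Because JacksonParser also backs from_json and JSON scans, the same session conf governs that path; a usage sketch, assuming a running SparkSession named spark and the public from_json/to_json/VariantType API:

import org.apache.spark.sql.functions.{col, from_json, to_json}
import org.apache.spark.sql.types.VariantType
import spark.implicits._

spark.conf.set("spark.sql.variant.allowDuplicateKeys", "true")
val df = Seq("""{"k": 1, "k": 2}""").toDF("j")
// The last occurrence of "k" wins, so this prints {"k":2}.
df.select(to_json(from_json(col("j"), VariantType))).show(false)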
@@ -4412,6 +4412,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val VARIANT_ALLOW_DUPLICATE_KEYS =
buildConf("spark.sql.variant.allowDuplicateKeys")
.internal()
.doc("When set to false, parsing variant from JSON will throw an error if there are " +
"duplicate keys in the input JSON object. When set to true, the parser will keep the " +
"last occurrence of all fields with the same key.")
.version("4.0.0")
.booleanConf
.createWithDefault(true)
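
A sketch of toggling the new flag at the session level (assuming a SparkSession named spark; the exact error surfaced through parse_json is not spelled out in this diff):

// Default (true): the last value of a duplicated key wins.
spark.sql("""SELECT parse_json('{"a": 1, "a": 2}')""").collect()

// Strict mode: a duplicate key in the input JSON makes parse_json fail.
spark.conf.set("spark.sql.variant.allowDuplicateKeys", "false")
spark.sql("""SELECT parse_json('{"a": 1, "a": 2}')""").collect()  // throws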

val LEGACY_CSV_ENABLE_DATE_TIME_PARSING_FALLBACK =
buildConf("spark.sql.legacy.csv.enableDateTimeParsingFallback")
.internal()
@@ -89,6 +89,12 @@ class VariantExpressionEvalUtilsSuite extends SparkFunSuite {
/* offset list */ 0, 2, 4, 6,
/* field data */ primitiveHeader(INT1), 1, primitiveHeader(INT1), 2, shortStrHeader(1), '3'),
Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c'))
check("""{"a": 1, "b": 2, "c": "3", "a": 4}""", Array(objectHeader(false, 1, 1),
/* size */ 3,
/* id list */ 0, 1, 2,
/* offset list */ 4, 0, 2, 6,
/* field data */ primitiveHeader(INT1), 2, shortStrHeader(1), '3', primitiveHeader(INT1), 4),
Array(VERSION, 3, 0, 1, 2, 3, 'a', 'b', 'c'))
check("""{"z": 1, "y": 2, "x": "3"}""", Array(objectHeader(false, 1, 1),
/* size */ 3,
/* id list */ 2, 1, 0,
@@ -119,7 +125,7 @@ class VariantExpressionEvalUtilsSuite extends SparkFunSuite {
)
assert(try_parse_json_output === null)
}
for (json <- Seq("", "[", "+1", "1a", """{"a": 1, "b": 2, "a": "3"}""")) {
for (json <- Seq("", "[", "+1", "1a")) {
checkException(json, "MALFORMED_RECORD_IN_PARSING.WITHOUT_SUGGESTION",
Map("badRecord" -> json, "failFastMode" -> "FAILFAST"))
}
@@ -57,6 +57,10 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
)
// scalastyle:on nonascii
check("[0.0, 1.00, 1.10, 1.23]", "[0,1,1.1,1.23]")
check(
"""{"c": [], "b": 0, "a": null, "a": {"x": 0, "x": 1}, "b": 1, "b": 2, "c": [3]}""",
"""{"a":{"x":1},"b":2,"c":[3]}"""
)
}
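
The expected string above can be reproduced directly in SQL; a small usage sketch assuming the default allowDuplicateKeys = true:

spark.sql("""SELECT to_json(parse_json('{"c": [], "b": 0, "a": null, "a": {"x": 0, "x": 1}, "b": 1, "b": 2, "c": [3]}'))""").show(false)
// Prints {"a":{"x":1},"b":2,"c":[3]}: duplicates are resolved recursively and the last value wins.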

test("from_json/to_json round-trip") {
@@ -146,7 +150,7 @@ class VariantEndToEndSuite extends QueryTest with SharedSparkSession {
val variantDF = df.select(parse_json(col("v")))
val plan = variantDF.queryExecution.executedPlan
assert(plan.isInstanceOf[WholeStageCodegenExec])
val v = VariantBuilder.parseJson("""{"a":1}""")
val v = VariantBuilder.parseJson("""{"a":1}""", false)
val expected = new VariantVal(v.getValue, v.getMetadata)
checkAnswer(variantDF, Seq(Row(expected)))
}
