# [SPARK-50815][PYTHON][SQL] Fix bug where passing null Variants in createDataFrame causes it to fail and add Variant support in createDataFrame in Spark Connect

### What changes were proposed in this pull request?

In this PR, we add a case to handle `None` in `VariantType.toInternal`. We also add support for Variant values in `createDataFrame` when using Spark Connect.

### Why are the changes needed?

Previously, `spark.createDataFrame([(VariantVal(bytearray([12, 1]), bytearray([1, 0, 0])),), (None,)], "v variant").show()` failed because `VariantType.toInternal` had no way of handling nulls.

Also, in the Spark Connect Python client, `createDataFrame` did not work with Variants prior to this PR; now it does.
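
For example, after this fix the call above runs without error (a minimal sketch, assuming an active `SparkSession` bound to `spark`):

```python
from pyspark.sql.types import VariantVal

# The None in the second row previously tripped the
# `assert isinstance(variant, VariantVal)` in VariantType.toInternal;
# it now maps to a SQL NULL variant.
spark.createDataFrame(
    [(VariantVal(bytearray([12, 1]), bytearray([1, 0, 0])),), (None,)],
    "v variant",
).show()
```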

### Does this PR introduce _any_ user-facing change?

Yes. It fixes a bug where `None` values couldn't be handled in Variant schemas, and it allows users to call `createDataFrame` with Variants in the Spark Connect Python client.

### How was this patch tested?

Unit tests in `python/pyspark/sql/tests/test_types.py`.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#49487 from harshmotw-db/harsh-motwani_data/variant_bug_fix.

Authored-by: Harsh Motwani <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
harshmotw-db authored and cloud-fan committed Jan 15, 2025
1 parent 21a37a7 commit 39bb2d8
Showing 3 changed files with 37 additions and 8 deletions.
16 changes: 8 additions & 8 deletions python/pyspark/sql/connect/conversion.py
```diff
@@ -104,6 +104,7 @@ def _need_converter(
     def _create_converter(
         dataType: DataType,
         nullable: bool = True,
+        variants_as_dicts: bool = False,  # some code paths may require python internal types
     ) -> Callable:
         assert dataType is not None and isinstance(dataType, DataType)
         assert isinstance(nullable, bool)
@@ -126,8 +127,7 @@ def convert_null(value: Any) -> Any:
 
             field_convs = [
                 LocalDataToArrowConversion._create_converter(
-                    field.dataType,
-                    field.nullable,
+                    field.dataType, field.nullable, variants_as_dicts
                 )
                 for field in dataType.fields
             ]
@@ -170,8 +170,7 @@ def convert_struct(value: Any) -> Any:
 
         elif isinstance(dataType, ArrayType):
             element_conv = LocalDataToArrowConversion._create_converter(
-                dataType.elementType,
-                dataType.containsNull,
+                dataType.elementType, dataType.containsNull, variants_as_dicts
             )
 
             def convert_array(value: Any) -> Any:
@@ -188,8 +187,7 @@ def convert_array(value: Any) -> Any:
         elif isinstance(dataType, MapType):
             key_conv = LocalDataToArrowConversion._create_converter(dataType.keyType)
             value_conv = LocalDataToArrowConversion._create_converter(
-                dataType.valueType,
-                dataType.valueContainsNull,
+                dataType.valueType, dataType.valueContainsNull, variants_as_dicts
             )
 
             def convert_map(value: Any) -> Any:
@@ -303,8 +301,11 @@ def convert_variant(value: Any) -> Any:
                     isinstance(value, dict)
                     and all(key in value for key in ["value", "metadata"])
                     and all(isinstance(value[key], bytes) for key in ["value", "metadata"])
+                    and not variants_as_dicts
                 ):
                     return VariantVal(value["value"], value["metadata"])
+                elif isinstance(value, VariantVal) and variants_as_dicts:
+                    return VariantType().toInternal(value)
                 else:
                     raise PySparkValueError(errorClass="MALFORMED_VARIANT")
 
@@ -331,8 +332,7 @@ def convert(data: Sequence[Any], schema: StructType) -> "pa.Table":
 
         column_convs = [
             LocalDataToArrowConversion._create_converter(
-                field.dataType,
-                field.nullable,
+                field.dataType, field.nullable, variants_as_dicts=True
             )
             for field in schema.fields
         ]
```
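
The effect of the new `variants_as_dicts` flag: the top-level `convert` path now serializes each `VariantVal` into the Python-internal `{"value": bytes, "metadata": bytes}` dict before the data is written into the Arrow table, instead of passing the `VariantVal` object through. A minimal sketch of the two representations, using only the public type APIs (the converter itself is internal):

```python
from pyspark.sql.types import VariantType, VariantVal

v = VariantVal.parseJson("2")

# With variants_as_dicts=True the converter emits the internal dict form,
# equivalent to VariantType().toInternal(v):
as_dict = VariantType().toInternal(v)  # {"value": b"...", "metadata": b"..."}

# fromInternal restores a VariantVal from that dict, so the two
# representations round-trip:
assert VariantType().fromInternal(as_dict).toJson() == "2"
```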
27 changes: 27 additions & 0 deletions python/pyspark/sql/tests/test_types.py
```diff
@@ -2251,6 +2251,33 @@ def test_variant_type(self):
         self.assertEqual(parse_json_spark_output.value, parse_json_python_output.value)
         self.assertEqual(parse_json_spark_output.metadata, parse_json_python_output.metadata)
 
+        # Test createDataFrame
+        create_df_variants = self.spark.createDataFrame(
+            [
+                (
+                    VariantVal.parseJson("2"),
+                    [VariantVal.parseJson("3")],
+                    {"v": VariantVal.parseJson("4")},
+                    {"v": VariantVal.parseJson("5")},
+                ),
+                (None, [None], {"v": None}, {"v": None}),
+                (None, None, None, None),
+            ],
+            "v variant, a array<variant>, s struct<v variant>, m map<string, variant>",
+        ).collect()
+        self.assertEqual(create_df_variants[0][0].toJson(), "2")
+        self.assertEqual(create_df_variants[0][1][0].toJson(), "3")
+        self.assertEqual(create_df_variants[0][2][0].toJson(), "4")
+        self.assertEqual(create_df_variants[0][3]["v"].toJson(), "5")
+        self.assertEqual(create_df_variants[1][0], None)
+        self.assertEqual(create_df_variants[1][1][0], None)
+        self.assertEqual(create_df_variants[1][2][0], None)
+        self.assertEqual(create_df_variants[1][3]["v"], None)
+        self.assertEqual(create_df_variants[2][0], None)
+        self.assertEqual(create_df_variants[2][1], None)
+        self.assertEqual(create_df_variants[2][2], None)
+        self.assertEqual(create_df_variants[2][3], None)
+
     def test_to_ddl(self):
         schema = StructType().add("a", NullType()).add("b", BooleanType()).add("c", BinaryType())
         self.assertEqual(schema.toDDL(), "a VOID,b BOOLEAN,c BINARY")
```
2 changes: 2 additions & 0 deletions python/pyspark/sql/types.py
```diff
@@ -1580,6 +1580,8 @@ def fromInternal(self, obj: Dict) -> Optional["VariantVal"]:
         return VariantVal(obj["value"], obj["metadata"])
 
     def toInternal(self, variant: Any) -> Any:
+        if variant is None:
+            return None
         assert isinstance(variant, VariantVal)
         return {"value": variant.value, "metadata": variant.metadata}
 
```
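
A quick illustration of the `toInternal` fix (a sketch, not part of the patch):

```python
from pyspark.sql.types import VariantType

# Before this change, toInternal(None) failed the isinstance assertion;
# a missing variant now simply maps to None (SQL NULL).
assert VariantType().toInternal(None) is None
```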
