Skip to content

Commit

Permalink
[SPARK-48278][PYTHON][CONNECT] Refine the string representation of `C…
Browse files Browse the repository at this point in the history
…ast`

### What changes were proposed in this pull request?
Refine the string representation of `Cast`

### Why are the changes needed?
Best-effort change to make the string representation consistent with Spark Classic

### Does this PR introduce _any_ user-facing change?
Spark Classic:
```
In [1]: from pyspark.sql import functions as sf

In [2]: sf.col("a").try_cast("int")
Out[2]: Column<'TRY_CAST(a AS INT)'>
```

Spark Connect, before this PR:
```
In [1]: from pyspark.sql import functions as sf

In [2]: sf.col("a").try_cast("int")
Out[2]: Column<'(a (int))'>
```

Spark Connect, after this PR:
```
In [1]: from pyspark.sql import functions as sf

In [2]: sf.col("a").try_cast("int")
Out[2]: Column<'TRY_CAST(a AS INT)'>
```

### How was this patch tested?
added tests

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#46585 from zhengruifeng/cast_str_repr.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
  • Loading branch information
zhengruifeng committed May 15, 2024
1 parent c03ebb4 commit e97ad0a
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
14 changes: 13 additions & 1 deletion python/pyspark/sql/connect/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,7 @@ def __init__(
) -> None:
super().__init__()
self._expr = expr
assert isinstance(data_type, (DataType, str))
self._data_type = data_type
if eval_mode is not None:
assert isinstance(eval_mode, str)
Expand All @@ -873,7 +874,18 @@ def to_plan(self, session: "SparkConnectClient") -> proto.Expression:
return fun

def __repr__(self) -> str:
    """Return a Classic-style SQL string for this cast, e.g. ``CAST(a AS INT)``.

    We cannot guarantee the string representations are exactly the same as
    Spark Classic's, e.g. for ``str(sf.col("a").cast("long"))``:

        Column<'CAST(a AS BIGINT)'>  <- Spark Classic
        Column<'CAST(a AS LONG)'>    <- Spark Connect
    """
    # __init__ asserts _data_type is either a DataType or a str, so the
    # else-branch value is already a str and needs no str() conversion.
    if isinstance(self._data_type, DataType):
        str_data_type = self._data_type.simpleString().upper()
    else:
        str_data_type = self._data_type.upper()
    # _eval_mode defaults to None; None == "try" is False, so no extra
    # "is not None" guard is needed. Only "try" mode changes the prefix.
    prefix = "TRY_CAST" if self._eval_mode == "try" else "CAST"
    return f"{prefix}({self._expr} AS {str_data_type})"


class UnresolvedNamedLambdaVariable(Expression):
Expand Down
13 changes: 12 additions & 1 deletion python/pyspark/sql/tests/test_column.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from itertools import chain
from pyspark.sql import Column, Row
from pyspark.sql import functions as sf
from pyspark.sql.types import StructType, StructField, LongType
from pyspark.sql.types import StructType, StructField, IntegerType, LongType
from pyspark.errors import AnalysisException, PySparkTypeError, PySparkValueError
from pyspark.testing.sqlutils import ReusedSQLTestCase

Expand Down Expand Up @@ -228,6 +228,17 @@ def test_alias_negative(self):
message_parameters={"arg_name": "metadata"},
)

def test_cast_str_representation(self):
    """Cast columns should render like Spark Classic: [TRY_]CAST(col AS TYPE)."""
    # Each pair is (argument accepted by cast/try_cast, rendered type name).
    cases = [
        ("int", "INT"),
        ("INT", "INT"),
        (IntegerType(), "INT"),
        (LongType(), "BIGINT"),
    ]
    for data_type, rendered in cases:
        self.assertEqual(
            str(sf.col("a").cast(data_type)),
            f"Column<'CAST(a AS {rendered})'>",
        )
        self.assertEqual(
            str(sf.col("a").try_cast(data_type)),
            f"Column<'TRY_CAST(a AS {rendered})'>",
        )

def test_cast_negative(self):
with self.assertRaises(PySparkTypeError) as pe:
self.spark.range(1).id.cast(123)
Expand Down

0 comments on commit e97ad0a

Please sign in to comment.