
Created the functionality to split the columns for issue #85 #92

Merged
merged 13 commits into from
Oct 7, 2023

Changes from all commits
2 changes: 1 addition & 1 deletion quinn/__init__.py
@@ -1,7 +1,7 @@

"""quinn API."""

from quinn.append_if_schema_identical import append_if_schema_identical
from quinn.split_columns import split_col
from quinn.dataframe_helpers import (
column_to_list,
create_df,
2 changes: 1 addition & 1 deletion quinn/functions.py
@@ -106,7 +106,7 @@ def forall(f: Callable[[Any], bool]) -> udf:
:return: A spark UDF which accepts a list of arguments and returns True if all
elements pass through the given boolean function, False otherwise.
:rtype: UserDefinedFunction
"""
"""

def temp_udf(list_: list) -> bool:
return all(map(f, list_))
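Not part of the diff: a minimal usage sketch of the forall function touched in the hunk above, assuming a SparkSession and an illustrative array column named nums.

from pyspark.sql import SparkSession, functions as F
from quinn.functions import forall

spark = SparkSession.builder.getOrCreate()

# Illustrative data: each row holds an array of numbers
df = spark.createDataFrame([([1, 2, 3],), ([1, 5, 9],)], ["nums"])

# forall(f) returns a Spark UDF; the resulting column is True only when
# every element of the array passes the predicate
all_small = forall(lambda n: n < 4)
df.withColumn("all_small", all_small(F.col("nums"))).show()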
108 changes: 108 additions & 0 deletions quinn/split_columns.py
@@ -0,0 +1,108 @@
from typing import List
from pyspark.sql import DataFrame
from pyspark.sql.functions import split, when, length, trim, udf
from pyspark.sql.types import IntegerType


def split_col(df: DataFrame, col_name: str, delimiter: str,
              new_col_names: List[str], mode: str = "strict", default: str = "default") -> DataFrame:
    """Split the given column on the delimiter and create new columns from the split values.

    :param df: The input DataFrame
    :type df: pyspark.sql.DataFrame
    :param col_name: The name of the column to split
    :type col_name: str
    :param delimiter: The delimiter to split the column on
    :type delimiter: str
    :param new_col_names: A list of strings for the new column names
    :type new_col_names: List[str]
    :param mode: The split mode. Can be "strict" or "permissive". Default is "strict"
    :type mode: str
    :param default: The default value assigned to missing or empty split values in "permissive" mode
    :type default: str
    :return: The resulting DataFrame with the split columns
    :rtype: pyspark.sql.DataFrame
    """
    # Check if the column to be split exists in the DataFrame
    if col_name not in df.columns:
        raise ValueError(f"Column '{col_name}' not found in DataFrame.")

    # Check if the delimiter is a string
    if not isinstance(delimiter, str):
        raise TypeError("Delimiter must be a string.")

    # Check if the new column names are a list of strings
    if not isinstance(new_col_names, list):
        raise ValueError("New column names must be a list of strings.")

    # Define a UDF to check the occurrence of the delimiter
    def num_delimiter(col_value1):
        # Treat a None value as zero delimiters
        if col_value1 is None:
            return 0

        # Count the delimiter occurrences and split the value on the delimiter
        no_of_delimiter = col_value1.count(delimiter)
        split_value = col_value1.split(delimiter)

        # Check if the number of delimiters matches the expected number of columns
        if no_of_delimiter != len(new_col_names) - 1:
            # If it does not, raise an IndexError mentioning the expected and found lengths
            raise IndexError(
                f"Expected {len(new_col_names)} elements after splitting on delimiter, "
                f"found {len(split_value)} elements")

        # If the length matches, reject null or empty split values in strict mode
        if any(x is None or x.strip() == '' for x in split_value[:len(new_col_names)]):
            raise ValueError("Null or empty values are not accepted for columns in strict mode")

        # If the above checks pass, return the count of delimiters
        return int(no_of_delimiter)

    num_udf = udf(num_delimiter, IntegerType())

    # Get the column expression for the column to be split
    col_expr = df[col_name]

    # Split the column by the delimiter
    split_col_expr = split(trim(col_expr), delimiter)

    # Check the split mode
    if mode == "strict":
        # Create an array of select expressions to create new columns from the split values
        select_exprs = [when(split_col_expr.getItem(i) != '', split_col_expr.getItem(i))
                        .alias(new_col_names[i])
                        for i in range(len(new_col_names))]

        # Select all the columns from the input DataFrame, along with the new split columns
        df = df.select("*", *select_exprs)
        # Validate every row with the UDF; invalid rows raise once the plan is evaluated
        df = df.withColumn("del_length", num_udf(df[col_name]))
        df.cache()
        # Drop the validation column and the original column
        df = df.select([c for c in df.columns if c not in {"del_length", col_name}])

elif mode == "permissive":

# Create an array of select expressions to create new columns from the split values
# Use the default value if a split value is missing or empty
select_exprs = select_exprs = [when(length(split_col_expr.getItem(i)) > 0, split_col_expr.getItem(i)
).otherwise(default).alias(new_col_names[i])
for i in range(len(new_col_names))]

# Select all the columns from the input DataFrame, along with the new split columns
# Drop the original column if the new columns were created successfully
df = df.select("*", *select_exprs).drop(col_name)
df.cache()

else:

raise ValueError(f"Invalid mode: {mode}")

# Return the DataFrame with the split columns
return df
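For reviewers, a minimal usage sketch contrasting the two modes (not part of the diff; the data and column names are illustrative):

from pyspark.sql import SparkSession
from quinn.split_columns import split_col

spark = SparkSession.builder.getOrCreate()

# One row is missing the delimiter entirely
df = spark.createDataFrame([("chrisXXmoe",), ("soloname",)], ["student_name"])

# Permissive mode: the missing piece falls back to the supplied default
split_col(df, "student_name", "XX", ["first", "last"],
          mode="permissive", default="n/a").show()

# Strict mode: the same input raises an IndexError from the validation
# UDF once an action (such as .show()) evaluates the plan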
39 changes: 39 additions & 0 deletions tests/test_split_columns.py
@@ -0,0 +1,39 @@
import pytest

import quinn
from tests.conftest import auto_inject_fixtures


@auto_inject_fixtures("spark")
def test_split_columns(spark):
# Create Spark DataFrame
data = [("chrisXXmoe", 2025, "bio"),
("davidXXbb", 2026, "physics"),
("sophiaXXraul", 2022, "bio"),
("fredXXli", 2025, "physics"),
("someXXperson", 2023, "math"),
("liXXyao", 2025, "physics")]

df = spark.createDataFrame(data, ["student_name", "graduation_year", "major"])
# Define the delimiter
delimiter = "XX"

# New column names
new_col_names = ["student_first_name", "student_last_name"]

col_name = "student_name"
mode = "strict"
# Call split_col() function to split "student_name" column
new_df = quinn.split_col(df, col_name, delimiter, new_col_names, mode)

# Show the resulting DataFrame
new_df.show()

# Verify the resulting DataFrame has the expected columns and values
assert set(new_df.columns) == set(["graduation_year", "major", "student_first_name", "student_last_name"])
assert new_df.count() == 6
assert new_df.filter("student_first_name = 'chris'").count() == 1
assert new_df.filter("student_last_name = 'moe'").count() == 1

col_name1 = "non_existent_column"
# Verify that a ValueError is raised when calling split_col() with a non-existent column name
assert quinn.split_col(df, col_name1, delimiter, new_col_names, mode) is not None, ValueError("Error: split_col "
"returned None")