diff --git a/butterfree/constants/data_type.py b/butterfree/constants/data_type.py
index e99525f7..6166f1fc 100644
--- a/butterfree/constants/data_type.py
+++ b/butterfree/constants/data_type.py
@@ -12,6 +12,7 @@
     IntegerType,
     LongType,
     StringType,
+    TimestampNTZType,
     TimestampType,
 )
 from typing_extensions import final
@@ -21,6 +22,7 @@
 class DataType(Enum):
     """Holds constants for data types within Butterfree."""
 
+    TIMESTAMP_NTZ = (TimestampNTZType(), "timestamp", "TIMESTAMP_NTZ")
     TIMESTAMP = (TimestampType(), "timestamp", "TIMESTAMP")
     BINARY = (BinaryType(), "boolean", "BINARY")
     BOOLEAN = (BooleanType(), "boolean", "BOOLEAN")
diff --git a/butterfree/transform/features/timestamp_feature.py b/butterfree/transform/features/timestamp_feature.py
index aa30dfc4..b4aee71e 100644
--- a/butterfree/transform/features/timestamp_feature.py
+++ b/butterfree/transform/features/timestamp_feature.py
@@ -41,6 +41,7 @@ class TimestampFeature(Feature):
 
     def __init__(
         self,
+        dtype: Optional[DataType] = DataType.TIMESTAMP,
         from_column: Optional[str] = None,
         transformation: Optional[TransformComponent] = None,
         from_ms: bool = False,
@@ -51,7 +52,7 @@ def __init__(
             name=TIMESTAMP_COLUMN,
             description=description,
             from_column=from_column,
-            dtype=DataType.TIMESTAMP,
+            dtype=dtype,
             transformation=transformation,
         )
         self.from_ms = from_ms
diff --git a/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py b/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py
index 79125339..16eb08e2 100644
--- a/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py
+++ b/tests/integration/butterfree/pipelines/test_feature_set_pipeline.py
@@ -50,7 +50,7 @@ def create_temp_view(dataframe: DataFrame, name):
 
 
 def create_db_and_table(spark, table_reader_id, table_reader_db, table_reader_table):
-    spark.sql(f"drop schema {table_reader_db} cascade")
+    spark.sql(f"drop schema if exists {table_reader_db} cascade")
     spark.sql(f"create database {table_reader_db}")
     spark.sql(f"use {table_reader_db}")
     spark.sql(
diff --git a/tests/unit/butterfree/transform/features/test_timestamp_feature.py b/tests/unit/butterfree/transform/features/test_timestamp_feature.py
index a5a688c2..a7befd90 100644
--- a/tests/unit/butterfree/transform/features/test_timestamp_feature.py
+++ b/tests/unit/butterfree/transform/features/test_timestamp_feature.py
@@ -1,18 +1,23 @@
-from pyspark.sql.types import StringType
+from datetime import datetime
+
+from pyspark.sql.types import StringType, StructField, StructType
 
+from butterfree.clients import SparkClient
 from butterfree.constants import DataType
 from butterfree.constants.columns import TIMESTAMP_COLUMN
 from butterfree.transform.features import TimestampFeature
 
 
 class TestTimestampFeature:
     def test_args_without_transformation(self):
 
         test_key = TimestampFeature(from_column="ts")
+        test_key_ntz = TimestampFeature(dtype=DataType.TIMESTAMP_NTZ, from_column="ts")
 
         assert test_key.name == TIMESTAMP_COLUMN
         assert test_key.from_column == "ts"
         assert test_key.dtype == DataType.TIMESTAMP
+        assert test_key_ntz.dtype == DataType.TIMESTAMP_NTZ
 
     def test_transform(self, feature_set_dataframe):
 
@@ -70,3 +75,43 @@ def test_transform_mask(self, feature_set_dataframe_date):
 
         assert df[0]["timestamp"] == "2020-02-07 00:00:00"
         assert df[1]["timestamp"] == "2020-02-08 00:00:00"
+
+    def test_no_timezone(self):
+
+        spark = SparkClient()
+        now = datetime.now()
+
+        # Set an explicit session timezone as the baseline
+        spark.conn.conf.set("spark.sql.session.timeZone", "GMT-5")
+
+        data = [(now, now)]
+        rdd = spark.conn.sparkContext.parallelize(data)
+
+        schema = StructType(
+            [
+                StructField("ts", DataType.TIMESTAMP.spark, True),
+                StructField("ts_ntz", DataType.TIMESTAMP_NTZ.spark, True),
+            ]
+        )
+        df = spark.conn.createDataFrame(rdd, schema)
+        df.createOrReplaceTempView("temp_tz_table")
+
+        df1 = spark.conn.sql("""SELECT ts, ts_ntz FROM temp_tz_table""")
+        df2 = df1.withColumns(
+            {"ts": df1.ts.cast(StringType()), "ts_ntz": df1.ts_ntz.cast(StringType())}
+        )
+        df2_vals = df2.collect()[0]
+
+        assert df2_vals.ts != df2_vals.ts_ntz
+
+        # After changing the timezone, ts must change and ts_ntz must keep its value
+        spark.conn.conf.set("spark.sql.session.timeZone", "GMT-7")
+
+        df3 = spark.conn.sql("""SELECT ts, ts_ntz FROM temp_tz_table""")
+        df4 = df3.withColumns(
+            {"ts": df3.ts.cast(StringType()), "ts_ntz": df3.ts_ntz.cast(StringType())}
+        )
+        df4_vals = df4.collect()[0]
+
+        assert df4_vals.ts != df2_vals.ts
+        assert df4_vals.ts_ntz == df2_vals.ts_ntz