Skip to content

Commit

Permalink
feat: NTZ
Browse files Browse the repository at this point in the history
  • Loading branch information
ralphrass committed Jun 12, 2024
1 parent cbda73d commit d6998d0
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 3 deletions.
2 changes: 2 additions & 0 deletions butterfree/constants/data_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
IntegerType,
LongType,
StringType,
TimestampNTZType,
TimestampType,
)
from typing_extensions import final
Expand All @@ -21,6 +22,7 @@
class DataType(Enum):
"""Holds constants for data types within Butterfree."""

TIMESTAMP_NTZ = (TimestampNTZType(), "timestamp", "TIMESTAMP_NTZ")
TIMESTAMP = (TimestampType(), "timestamp", "TIMESTAMP")
BINARY = (BinaryType(), "boolean", "BINARY")
BOOLEAN = (BooleanType(), "boolean", "BOOLEAN")
Expand Down
3 changes: 2 additions & 1 deletion butterfree/transform/features/timestamp_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class TimestampFeature(Feature):

def __init__(
self,
dtype: Optional[DataType] = DataType.TIMESTAMP,
from_column: Optional[str] = None,
transformation: Optional[TransformComponent] = None,
from_ms: bool = False,
Expand All @@ -51,7 +52,7 @@ def __init__(
name=TIMESTAMP_COLUMN,
description=description,
from_column=from_column,
dtype=DataType.TIMESTAMP,
dtype=dtype,
transformation=transformation,
)
self.from_ms = from_ms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def create_temp_view(dataframe: DataFrame, name):


def create_db_and_table(spark, table_reader_id, table_reader_db, table_reader_table):
spark.sql(f"drop schema {table_reader_db} cascade")
spark.sql(f"drop schema if exists {table_reader_db} cascade")
spark.sql(f"create database {table_reader_db}")
spark.sql(f"use {table_reader_db}")
spark.sql(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
from pyspark.sql.types import StringType
from datetime import datetime

from pyspark.sql.types import StringType, StructField, StructType

from butterfree.clients import SparkClient
from butterfree.constants import DataType
from butterfree.constants.columns import TIMESTAMP_COLUMN
from butterfree.transform.features import TimestampFeature

# from pyspark.sql.types import *


class TestTimestampFeature:
def test_args_without_transformation(self):

test_key = TimestampFeature(from_column="ts")
test_key_ntz = TimestampFeature(dtype=DataType.TIMESTAMP_NTZ, from_column="ts")

assert test_key.name == TIMESTAMP_COLUMN
assert test_key.from_column == "ts"
assert test_key.dtype == DataType.TIMESTAMP
assert test_key_ntz.dtype == DataType.TIMESTAMP_NTZ

def test_transform(self, feature_set_dataframe):

Expand Down Expand Up @@ -70,3 +77,43 @@ def test_transform_mask(self, feature_set_dataframe_date):

assert df[0]["timestamp"] == "2020-02-07 00:00:00"
assert df[1]["timestamp"] == "2020-02-08 00:00:00"

def test_no_timezone(self):

spark = SparkClient()
now = datetime.now()

# Testing a new timezone
spark.conn.conf.set("spark.sql.session.timeZone", "GMT-5")

l = [(now, now)]
rdd = spark.conn.sparkContext.parallelize(l)

schema = StructType(
[
StructField("ts", DataType.TIMESTAMP.spark, True),
StructField("ts_ntz", DataType.TIMESTAMP_NTZ.spark, True),
]
)
df = spark.conn.createDataFrame(rdd, schema)
df.createOrReplaceTempView("temp_tz_table")

df1 = spark.conn.sql("""SELECT ts, ts_ntz FROM temp_tz_table""")
df2 = df1.withColumns(
{"ts": df1.ts.cast(StringType()), "ts_ntz": df1.ts_ntz.cast(StringType())}
)
df2_vals = df2.collect()[0]

assert df2_vals.ts != df2_vals.ts_ntz

# New timezone, the column with TZ must have a different value, and the column NTZ must keep its value
spark.conn.conf.set("spark.sql.session.timeZone", "GMT-7")

df3 = spark.conn.sql("""SELECT ts, ts_ntz FROM temp_tz_table""")
df4 = df3.withColumns(
{"ts": df1.ts.cast(StringType()), "ts_ntz": df1.ts_ntz.cast(StringType())}
)
df4_vals = df4.collect()[0]

assert df4_vals.ts != df2_vals.ts
assert df4_vals.ts_ntz == df2_vals.ts_ntz

0 comments on commit d6998d0

Please sign in to comment.