Skip to content

Commit

Permalink
feat: new tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ralphrass committed Jun 13, 2024
1 parent d6998d0 commit b66952e
Showing 1 changed file with 35 additions and 4 deletions.
39 changes: 35 additions & 4 deletions tests/unit/butterfree/transform/features/test_timestamp_feature.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from datetime import datetime

import pytz
from pyspark.sql.types import StringType, StructField, StructType

from butterfree.clients import SparkClient
Expand Down Expand Up @@ -78,16 +79,16 @@ def test_transform_mask(self, feature_set_dataframe_date):
assert df[0]["timestamp"] == "2020-02-07 00:00:00"
assert df[1]["timestamp"] == "2020-02-08 00:00:00"

def test_no_timezone(self):
def test_timezone_configs(self):

spark = SparkClient()
now = datetime.now()

# Testing a new timezone
spark.conn.conf.set("spark.sql.session.timeZone", "GMT-5")

l = [(now, now)]
rdd = spark.conn.sparkContext.parallelize(l)
time_list = [(now, now)]
rdd = spark.conn.sparkContext.parallelize(time_list)

schema = StructType(
[
Expand All @@ -106,7 +107,7 @@ def test_no_timezone(self):

assert df2_vals.ts != df2_vals.ts_ntz

# New timezone, the column with TZ must have a different value, and the column NTZ must keep its value
# New TZ: the TZ column must change value; the NTZ column must keep its value
spark.conn.conf.set("spark.sql.session.timeZone", "GMT-7")

df3 = spark.conn.sql("""SELECT ts, ts_ntz FROM temp_tz_table""")
Expand All @@ -117,3 +118,33 @@ def test_no_timezone(self):

assert df4_vals.ts != df2_vals.ts
assert df4_vals.ts_ntz == df2_vals.ts_ntz

def test_timezone(self):
    """Check that TIMESTAMP shifts a zoned value while TIMESTAMP_NTZ keeps it.

    A timezone-aware "now" in US/Pacific is serialized to a string, loaded
    through Spark's JSON reader and then cast to both timestamp flavours.
    Values are compared at minute precision using ``datetime_mask``.

    NOTE(review): the inequality assertion presumably relies on the
    environment rendering timestamps in a zone other than US/Pacific --
    confirm against the CI configuration.
    """
    spark = SparkClient()

    my_date = datetime.now(pytz.timezone("US/Pacific"))

    # Minute precision: seconds and microseconds are ignored when comparing.
    datetime_mask = "%Y-%m-%d %H:%M"

    data = [
        {"id": 1, TIMESTAMP_COLUMN: str(my_date), "feature": 100},
        {"id": 2, TIMESTAMP_COLUMN: str(my_date), "feature": 200},
    ]

    # Use the public ``sparkContext`` accessor (as the sibling timezone test
    # does) instead of reaching into the private ``_sc`` attribute.
    rdd = spark.conn.sparkContext.parallelize(data, 1)
    source_df = spark.conn.read.json(rdd)
    source_df.createOrReplaceTempView("time_table")

    # Select by the same constant used to build the rows, so the query and
    # the data cannot drift apart.
    ts_df = spark.sql(f"SELECT {TIMESTAMP_COLUMN} AS ts FROM time_table")

    # The column is still a string here; parse it back to get the reference
    # wall-clock value as written in the original timezone.
    raw_ts = ts_df.collect()[0].ts
    time_value = datetime.fromisoformat(raw_ts).strftime(datetime_mask)

    df_different_timezone = ts_df.withColumn(
        "ts", ts_df.ts.cast(DataType.TIMESTAMP.spark)
    )
    df_no_timezone = ts_df.withColumn(
        "ts", ts_df.ts.cast(DataType.TIMESTAMP_NTZ.spark)
    )

    shifted_value = df_different_timezone.collect()[0].ts.strftime(datetime_mask)
    unshifted_value = df_no_timezone.collect()[0].ts.strftime(datetime_mask)

    assert shifted_value != time_value
    assert unshifted_value == time_value

0 comments on commit b66952e

Please sign in to comment.