Skip to content

Commit

Permalink
fix: test
Browse files Browse the repository at this point in the history
  • Loading branch information
ralphrass committed Aug 15, 2024
1 parent f5c48ef commit f6c4eeb
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 112 deletions.
5 changes: 2 additions & 3 deletions butterfree/clients/spark_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
import json
from typing import Any, Dict, List, Optional, Union

import pyspark
from delta import *
from pyspark.sql import DataFrame, DataFrameReader, SparkSession
from pyspark.sql.streaming import DataStreamReader, StreamingQuery
from pyspark.sql.types import StructType

from butterfree.clients import AbstractClient

import pyspark
from delta import *


class SparkClient(AbstractClient):
"""Handle Spark session connection.
Expand Down
14 changes: 10 additions & 4 deletions butterfree/load/writers/delta_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,17 @@ def merge(
table_exists = client.conn.catalog.tableExists(full_table_name)
# table_is_delta = DeltaTable.isDeltaTable(client.conn, path)


if table_exists:
pd_df = client.conn.sql(f"DESCRIBE TABLE EXTENDED {full_table_name}").toPandas()
provider = pd_df.reset_index().groupby(['col_name'])['data_type'].aggregate('first').Provider
table_is_delta = (provider.lower() == "delta")
pd_df = client.conn.sql(
f"DESCRIBE TABLE EXTENDED {full_table_name}"
).toPandas()
provider = (
pd_df.reset_index()
.groupby(["col_name"])["data_type"]
.aggregate("first")
.Provider
)
table_is_delta = provider.lower() == "delta"

if not table_is_delta:
DeltaWriter()._convert_to_delta(client, full_table_name)
Expand Down
120 changes: 29 additions & 91 deletions tests/unit/butterfree/load/writers/test_delta_writer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
from unittest import mock

import pytest

Expand All @@ -7,38 +8,36 @@

DELTA_LOCATION = "spark-warehouse"


class TestDeltaWriter:

def __checkFileExists(self, file_name: str = "test_delta_table") -> bool:
return os.path.exists(os.path.join(DELTA_LOCATION, file_name))

def test_merge(self):
@pytest.fixture
def merge_builder_mock(self):
builder = mock.MagicMock()
builder.whenMatchedDelete.return_value = builder
builder.whenMatchedUpdateAll.return_value = builder
builder.whenNotMatchedInsertAll.return_value = builder
return builder

client = SparkClient()
def test_merge(self, feature_set_dataframe, merge_builder_mock):

# create_delta_table(client)
client.conn.sql(
"CREATE TABLE test_delta_table (id INT, feature STRING) USING DELTA "
)
client.conn.sql(
"INSERT INTO test_delta_table(id, feature) VALUES(1, 'test') "
)

source = client.conn.createDataFrame([(1, "test2")], ["id", "feature"])
client = SparkClient()
delta_writer = DeltaWriter()
delta_writer.merge = mock.MagicMock()

DeltaWriter().merge(
client,
None,
"test_delta_table",
"spark-warehouse/test_delta_table",
["id"],
source,
feature_set_dataframe,
)

df = client.conn.read.table("test_delta_table")

assert df is not None
assert df.toPandas().feature[0] == "test2"
assert merge_builder_mock.execute.assert_called_once

# Step 2
source = client.conn.createDataFrame(
Expand All @@ -56,92 +55,31 @@ def test_merge(self):
"id > 2",
)

df = client.conn.read.table("test_delta_table")

assert df is not None
assert df.toPandas().feature[0] == "test2"
assert merge_builder_mock.execute.assert_called_once

client.conn.sql("DROP TABLE test_delta_table")
def test_optimize(self, mocker):

def test_merge_from_historical_writer(
self, feature_set, feature_set_dataframe
):
# given
client = SparkClient()
writer = HistoricalFeatureStoreWriter()
client.conn.sql("CREATE SCHEMA test")
client.conn.sql(
"""CREATE TABLE test.feature_set
(id INT, feature STRING, timestamp TIMESTAMP) USING DELTA """
)
client.conn.sql(
"""INSERT INTO test.feature_set(id, feature, timestamp)
VALUES(1, 'test', cast(date_format('2019-12-31', 'yyyy-MM-dd') as timestamp))"""
conn_mock = mocker.patch(
"butterfree.clients.SparkClient.conn", return_value=mock.Mock()
)

# when
writer.write(
feature_set=feature_set,
dataframe=feature_set_dataframe,
spark_client=client,
merge_on=["id", "timestamp"],
)

result_df = client.conn.read.table("test.feature_set")
rpd = result_df.toPandas()
rpd_filtered = rpd.loc[(rpd['id']==1) & (rpd['timestamp'] == '2019-12-31')]

assert result_df is not None
assert str(rpd_filtered.feature.values[0]) == '100'

client.conn.sql("DROP TABLE test.feature_set")
client.conn.sql("DROP SCHEMA test")

def test_optimize(self):

client = SparkClient()
temp_file = "test_delta"

df = client.conn.createDataFrame(
[("a", 1), ("a", 2)], ["key", "value"]
).repartition(1)
df.write.mode("overwrite").format("delta").save(temp_file)
df = client.conn.createDataFrame(
[("a", 3), ("a", 4)], ["key", "value"]
).repartition(1)
df.write.format("delta").save(temp_file, mode="append")
df = client.conn.createDataFrame(
[("b", 1), ("b", 2)], ["key", "value"]
).repartition(1)
df.write.format("delta").save(temp_file, mode="append")

dw = DeltaWriter()

# dw.optimize = MagicMock(spark_client_fixture)
dw.optimize(client)
dw.optimize = mock.MagicMock(client)
dw.optimize(client, "a_table")

# assertions
# dw.optimize.assert_called_once_with(spark_client_fixture)
conn_mock.assert_called_once

def test_vacuum(self):
def test_vacuum(self, mocker):

client = SparkClient()

client.conn.sql(
"CREATE TABLE test_delta_table_v (id INT, feature STRING) USING DELTA "
conn_mock = mocker.patch(
"butterfree.clients.SparkClient.conn", return_value=mock.Mock()
)
client.conn.sql(
"INSERT INTO test_delta_table_v(id, feature) VALUES(1, 'test') "
)

client.conn.conf.set(
"spark.databricks.delta.retentionDurationCheck.enabled", "false"
)

dw = DeltaWriter()
retention_hours = 24
dw.vacuum = mock.MagicMock(client)

DeltaWriter().vacuum("test_delta_table_v", retention_hours, client)

assert self.__checkFileExists("test_delta_table_v") is True
dw.vacuum("a_table", retention_hours, client)

client.conn.sql("DROP TABLE test_delta_table_v")
conn_mock.assert_called_once
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import datetime
import random
from unittest import mock

import pytest
from pyspark.sql.functions import spark_partition_id

from butterfree.clients import SparkClient
from butterfree.load.processing import json_transform
from butterfree.load.writers import HistoricalFeatureStoreWriter, DeltaWriter
from butterfree.load.writers import DeltaWriter, HistoricalFeatureStoreWriter
from butterfree.testing.dataframe import assert_dataframe_equality
from unittest import mock


class TestHistoricalFeatureStoreWriter:
Expand Down Expand Up @@ -146,15 +146,6 @@ def test_write_in_debug_mode_with_interval_mode(
# then
assert_dataframe_equality(historical_feature_set_dataframe, result_df)


@pytest.fixture
def merge_builder_mock(self):
builder = mock.MagicMock()
builder.whenMatchedDelete.return_value = builder
builder.whenMatchedUpdateAll.return_value = builder
builder.whenNotMatchedInsertAll.return_value = builder
return builder

def test_merge_from_historical_writer(
self, feature_set, feature_set_dataframe, mocker, merge_builder_mock
):
Expand All @@ -164,7 +155,9 @@ def test_merge_from_historical_writer(
spark_client.write_table = mocker.stub("write_table")
writer = HistoricalFeatureStoreWriter()

static_mock = mocker.patch("butterfree.load.writers.DeltaWriter.merge", return_value=mock.Mock())
static_mock = mocker.patch(
"butterfree.load.writers.DeltaWriter.merge", return_value=mock.Mock()
)

# when
writer.write(
Expand All @@ -176,7 +169,6 @@ def test_merge_from_historical_writer(

assert static_mock.call_count == 1


def test_validate(self, historical_feature_set_dataframe, mocker, feature_set):
# given
spark_client = mocker.stub("spark_client")
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/butterfree/transform/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest.mock import Mock

import pyspark.pandas as ps
from pyspark.sql import functions
from pytest import fixture

Expand All @@ -13,7 +14,6 @@
SparkFunctionTransform,
)
from butterfree.transform.utils import Function
import pyspark.pandas as ps


def make_dataframe(spark_context, spark_session):
Expand Down

0 comments on commit f6c4eeb

Please sign in to comment.