From 71b4f2db48f637b73ab9c322b846f377ed689556 Mon Sep 17 00:00:00 2001 From: Stefan Krawczyk Date: Tue, 19 Dec 2023 11:32:46 -0800 Subject: [PATCH] Extends Pandera validator to handle more DF types We were assuming only pandas annotated functions. This changes that and ensure that functions annotated with a dask datatype will work. Note, added pyspark without adding a test. Don't want to require having pyspark for unit tests just yet... --- .ci/test.sh | 1 + hamilton/data_quality/pandera_validators.py | 30 ++++--- .../pandera/test_pandera_data_quality.py | 82 +++++++++++++++++++ 3 files changed, 103 insertions(+), 10 deletions(-) diff --git a/.ci/test.sh b/.ci/test.sh index 681e8117f..aff9ad210 100755 --- a/.ci/test.sh +++ b/.ci/test.sh @@ -27,6 +27,7 @@ fi if [[ ${TASK} == "integrations" ]]; then pip install -e '.[pandera]' + pip install dask pytest tests/integrations exit 0 fi diff --git a/hamilton/data_quality/pandera_validators.py b/hamilton/data_quality/pandera_validators.py index 86049e632..4c11deeeb 100644 --- a/hamilton/data_quality/pandera_validators.py +++ b/hamilton/data_quality/pandera_validators.py @@ -1,11 +1,13 @@ -from typing import Type +from typing import Any, Type -import pandas as pd import pandera as pa +from hamilton import registry from hamilton.data_quality import base from hamilton.htypes import custom_subclass_check +pandera_supported_extensions = frozenset(["pandas", "dask", "pyspark_pandas"]) + class PanderaDataFrameValidator(base.BaseDefaultValidator): """Pandera schema validator for dataframes""" @@ -16,14 +18,18 @@ def __init__(self, schema: pa.DataFrameSchema, importance: str): @classmethod def applies_to(cls, datatype: Type[Type]) -> bool: - return custom_subclass_check( - datatype, pd.DataFrame - ) # TODO -- allow for modin, etc. as they come for free with pandera + for extension_name in pandera_supported_extensions: + if extension_name in registry.DF_TYPE_AND_COLUMN_TYPES: + df_type = registry.DF_TYPE_AND_COLUMN_TYPES[extension_name][registry.DATAFRAME_TYPE] + result = custom_subclass_check(datatype, df_type) + if result: + return True + return False def description(self) -> str: return "Validates that the returned dataframe matches the pander" - def validate(self, data: pd.DataFrame) -> base.ValidationResult: + def validate(self, data: Any) -> base.ValidationResult: try: result = self.schema.validate(data, lazy=True, inplace=True) if hasattr(result, "dask"): @@ -56,14 +62,18 @@ def __init__(self, schema: pa.SeriesSchema, importance: str): @classmethod def applies_to(cls, datatype: Type[Type]) -> bool: - return custom_subclass_check( - datatype, pd.Series - ) # TODO -- allow for modin, etc. as they come for free with pandera + for extension_name in pandera_supported_extensions: + if extension_name in registry.DF_TYPE_AND_COLUMN_TYPES: + df_type = registry.DF_TYPE_AND_COLUMN_TYPES[extension_name][registry.COLUMN_TYPE] + result = custom_subclass_check(datatype, df_type) + if result: + return True + return False def description(self) -> str: pass - def validate(self, data: pd.Series) -> base.ValidationResult: + def validate(self, data: Any) -> base.ValidationResult: try: result = self.schema.validate(data, lazy=True, inplace=True) if hasattr(result, "dask"): diff --git a/tests/integrations/pandera/test_pandera_data_quality.py b/tests/integrations/pandera/test_pandera_data_quality.py index 64e67a7e6..069b53404 100644 --- a/tests/integrations/pandera/test_pandera_data_quality.py +++ b/tests/integrations/pandera/test_pandera_data_quality.py @@ -1,5 +1,6 @@ import sys +import dask.dataframe as dd import numpy as np import pandas as pd import pandera as pa @@ -149,3 +150,84 @@ def foo() -> pd.DataFrame: n = node.Node.from_fn(foo) with pytest.raises(base.InvalidDecoratorException): h_pandera.check_output().get_validators(n) + + +def test_pandera_decorator_dask_df(): + """Validates that the function can be annotated with a dask dataframe type it'll work appropriately. + + Install dask if this fails. + """ + schema = pa.DataFrameSchema( + { + "year": pa.Column(int, pa.Check(lambda s: s > 2000)), + "month": pa.Column(str), + "day": pa.Column(str), + }, + index=pa.Index(int), + strict=True, + ) + from hamilton.function_modifiers import check_output + + @check_output(schema=schema) + def foo(fail: bool = False) -> dd.DataFrame: + if fail: + return dd.from_pandas( + pd.DataFrame( + { + "year": ["-2001", "-2002", "-2003"], + "month": ["-13", "-6", "120"], + "day": ["700", "-156", "367"], + } + ), + npartitions=1, + ) + return dd.from_pandas( + pd.DataFrame( + { + "year": [2001, 2002, 2003], + "month": ["3", "6", "12"], + "day": ["200", "156", "365"], + } + ), + npartitions=1, + ) + + n = node.Node.from_fn(foo) + validators = check_output(schema=schema).get_validators(n) + assert len(validators) == 1 + (validator,) = validators + result_success = validator.validate(n()) # should not fail + assert result_success.passes + result_success = validator.validate(n(True)) # should fail + assert not result_success.passes + + +def test_pandera_decorator_dask_series(): + """Validates that the function can be annotated with a dask series type it'll work appropriately. + Install dask if this fails. + """ + schema = pa.SeriesSchema( + str, + checks=[ + pa.Check(lambda s: s.str.startswith("foo")), + pa.Check(lambda s: s.str.endswith("bar")), + pa.Check(lambda x: len(x) > 3, element_wise=True), + ], + nullable=False, + ) + from hamilton.function_modifiers import check_output + + @check_output(schema=schema) + def foo(fail: bool = False) -> dd.Series: + if fail: + return dd.from_pandas(pd.Series(["xfoobar", "xfoobox", "xfoobaz"]), npartitions=1) + return dd.from_pandas(pd.Series(["foobar", "foobar", "foobar"]), npartitions=1) + + n = node.Node.from_fn(foo) + validators = check_output(schema=schema).get_validators(n) + assert len(validators) == 1 + (validator,) = validators + result_success = validator.validate(n()) # should not fail + assert result_success.passes + result_success = validator.validate(n(True)) # should fail + assert not result_success.passes