Extends Pandera validator to handle more DF types
We were assuming only pandas-annotated functions. This change ensures that
functions annotated with a dask dataframe type will work as well.

Note: pyspark support was added without a test, since we don't want to require pyspark
for unit tests just yet...
skrawcz committed Dec 21, 2023
1 parent 793f421 commit 71b4f2d
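
For context, a minimal sketch of the pattern this commit enables, based on the tests added below (the function name is illustrative; dask must be installed):

import dask.dataframe as dd
import pandas as pd
import pandera as pa

from hamilton.function_modifiers import check_output

schema = pa.DataFrameSchema({"year": pa.Column(int, pa.Check(lambda s: s > 2000))})

@check_output(schema=schema)
def yearly_data() -> dd.DataFrame:
    # The return annotation is a dask DataFrame rather than pd.DataFrame;
    # the pandera validator now applies to this type as well.
    return dd.from_pandas(pd.DataFrame({"year": [2001, 2002, 2003]}), npartitions=1)
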
Showing 3 changed files with 103 additions and 10 deletions.
1 change: 1 addition & 0 deletions .ci/test.sh
@@ -27,6 +27,7 @@ fi

 if [[ ${TASK} == "integrations" ]]; then
     pip install -e '.[pandera]'
+    pip install dask
     pytest tests/integrations
     exit 0
 fi
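
To reproduce locally, mirror this CI change: install the pandera extra and dask (pip install -e '.[pandera]' and pip install dask), then run pytest tests/integrations.
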
30 changes: 20 additions & 10 deletions hamilton/data_quality/pandera_validators.py
@@ -1,11 +1,13 @@
-from typing import Type
+from typing import Any, Type
 
 import pandas as pd
 import pandera as pa
 
+from hamilton import registry
 from hamilton.data_quality import base
 from hamilton.htypes import custom_subclass_check
 
+pandera_supported_extensions = frozenset(["pandas", "dask", "pyspark_pandas"])
 
 
 class PanderaDataFrameValidator(base.BaseDefaultValidator):
     """Pandera schema validator for dataframes"""
@@ -16,14 +18,18 @@ def __init__(self, schema: pa.DataFrameSchema, importance: str):

     @classmethod
     def applies_to(cls, datatype: Type[Type]) -> bool:
-        return custom_subclass_check(
-            datatype, pd.DataFrame
-        )  # TODO -- allow for modin, etc. as they come for free with pandera
+        for extension_name in pandera_supported_extensions:
+            if extension_name in registry.DF_TYPE_AND_COLUMN_TYPES:
+                df_type = registry.DF_TYPE_AND_COLUMN_TYPES[extension_name][registry.DATAFRAME_TYPE]
+                result = custom_subclass_check(datatype, df_type)
+                if result:
+                    return True
+        return False
 
     def description(self) -> str:
         return "Validates that the returned dataframe matches the pandera schema"
 
-    def validate(self, data: pd.DataFrame) -> base.ValidationResult:
+    def validate(self, data: Any) -> base.ValidationResult:
         try:
             result = self.schema.validate(data, lazy=True, inplace=True)
             if hasattr(result, "dask"):
@@ -56,14 +62,18 @@ def __init__(self, schema: pa.SeriesSchema, importance: str):

     @classmethod
     def applies_to(cls, datatype: Type[Type]) -> bool:
-        return custom_subclass_check(
-            datatype, pd.Series
-        )  # TODO -- allow for modin, etc. as they come for free with pandera
+        for extension_name in pandera_supported_extensions:
+            if extension_name in registry.DF_TYPE_AND_COLUMN_TYPES:
+                df_type = registry.DF_TYPE_AND_COLUMN_TYPES[extension_name][registry.COLUMN_TYPE]
+                result = custom_subclass_check(datatype, df_type)
+                if result:
+                    return True
+        return False
 
     def description(self) -> str:
         return "Validates that the returned series matches the pandera schema"
 
-    def validate(self, data: pd.Series) -> base.ValidationResult:
+    def validate(self, data: Any) -> base.ValidationResult:
         try:
             result = self.schema.validate(data, lazy=True, inplace=True)
             if hasattr(result, "dask"):
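
A note on the mechanism: registry.DF_TYPE_AND_COLUMN_TYPES maps each loaded extension ("pandas", "dask", "pyspark_pandas") to its dataframe and column types, so applies_to only matches types whose backing plugin is actually registered. A simplified sketch of the shape this code assumes, using plain issubclass in place of custom_subclass_check (the key constants and entries here are illustrative stand-ins, not the actual registry internals):

import pandas as pd

# Hypothetical stand-ins for hamilton.registry internals, for illustration only:
DATAFRAME_TYPE = "dataframe_type"
COLUMN_TYPE = "column_type"
DF_TYPE_AND_COLUMN_TYPES = {
    "pandas": {DATAFRAME_TYPE: pd.DataFrame, COLUMN_TYPE: pd.Series},
    # "dask" / "pyspark_pandas" entries appear only once those plugins are loaded.
}

def applies_to(datatype: type) -> bool:
    # Mirrors the new validator logic: return True on the first registered match.
    for extension_name in ("pandas", "dask", "pyspark_pandas"):
        if extension_name in DF_TYPE_AND_COLUMN_TYPES:
            df_type = DF_TYPE_AND_COLUMN_TYPES[extension_name][DATAFRAME_TYPE]
            if issubclass(datatype, df_type):
                return True
    return False
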
82 changes: 82 additions & 0 deletions tests/integrations/pandera/test_pandera_data_quality.py
@@ -1,5 +1,6 @@
 import sys
 
+import dask.dataframe as dd
 import numpy as np
 import pandas as pd
 import pandera as pa
@@ -149,3 +150,84 @@ def foo() -> pd.DataFrame:
     n = node.Node.from_fn(foo)
     with pytest.raises(base.InvalidDecoratorException):
         h_pandera.check_output().get_validators(n)
+
+
+def test_pandera_decorator_dask_df():
+    """Validates that a function annotated with a dask dataframe type works appropriately.
+    Install dask if this fails.
+    """
+    schema = pa.DataFrameSchema(
+        {
+            "year": pa.Column(int, pa.Check(lambda s: s > 2000)),
+            "month": pa.Column(str),
+            "day": pa.Column(str),
+        },
+        index=pa.Index(int),
+        strict=True,
+    )
+    from hamilton.function_modifiers import check_output
+
+    @check_output(schema=schema)
+    def foo(fail: bool = False) -> dd.DataFrame:
+        if fail:
+            return dd.from_pandas(
+                pd.DataFrame(
+                    {
+                        "year": ["-2001", "-2002", "-2003"],
+                        "month": ["-13", "-6", "120"],
+                        "day": ["700", "-156", "367"],
+                    }
+                ),
+                npartitions=1,
+            )
+        return dd.from_pandas(
+            pd.DataFrame(
+                {
+                    "year": [2001, 2002, 2003],
+                    "month": ["3", "6", "12"],
+                    "day": ["200", "156", "365"],
+                }
+            ),
+            npartitions=1,
+        )
+
+    n = node.Node.from_fn(foo)
+    validators = check_output(schema=schema).get_validators(n)
+    assert len(validators) == 1
+    (validator,) = validators
+    result_success = validator.validate(n())  # should not fail
+    assert result_success.passes
+    result_success = validator.validate(n(True))  # should fail
+    assert not result_success.passes
+
+
+def test_pandera_decorator_dask_series():
+    """Validates that a function annotated with a dask series type works appropriately.
+    Install dask if this fails.
+    """
+    schema = pa.SeriesSchema(
+        str,
+        checks=[
+            pa.Check(lambda s: s.str.startswith("foo")),
+            pa.Check(lambda s: s.str.endswith("bar")),
+            pa.Check(lambda x: len(x) > 3, element_wise=True),
+        ],
+        nullable=False,
+    )
+    from hamilton.function_modifiers import check_output
+
+    @check_output(schema=schema)
+    def foo(fail: bool = False) -> dd.Series:
+        if fail:
+            return dd.from_pandas(pd.Series(["xfoobar", "xfoobox", "xfoobaz"]), npartitions=1)
+        return dd.from_pandas(pd.Series(["foobar", "foobar", "foobar"]), npartitions=1)
+
+    n = node.Node.from_fn(foo)
+    validators = check_output(schema=schema).get_validators(n)
+    assert len(validators) == 1
+    (validator,) = validators
+    result_success = validator.validate(n())  # should not fail
+    assert result_success.passes
+    result_success = validator.validate(n(True))  # should fail
+    assert not result_success.passes
