From e570014f6e9f60f9485ffc1749b46bfa7d855de2 Mon Sep 17 00:00:00 2001
From: Stefan Krawczyk
Date: Sun, 4 Feb 2024 21:14:32 -0800
Subject: [PATCH] Adds simple str contains and does not contain validator

This will enable someone to do:

```
@check_output(contains=["this is a sentence"], does_not_contain=["foo", "bar"])
def llm_call(...) -> str:
    return response
```
---
 hamilton/data_quality/default_validators.py | 69 +++++++++++++++++++++
 tests/test_default_data_quality.py          |  8 +++
 2 files changed, 77 insertions(+)

diff --git a/hamilton/data_quality/default_validators.py b/hamilton/data_quality/default_validators.py
index 6219888e6..ec4b7780d 100644
--- a/hamilton/data_quality/default_validators.py
+++ b/hamilton/data_quality/default_validators.py
@@ -407,6 +407,73 @@ def arg(cls) -> str:
         return "allow_none"
 
 
+class StrContainsValidator(base.BaseDefaultValidator):
+    def __init__(self, contains: Union[str, List[str]], importance: str):
+        super(StrContainsValidator, self).__init__(importance)
+        if isinstance(contains, str):
+            self.contains = [contains]
+        else:
+            self.contains = contains
+
+    @classmethod
+    def applies_to(cls, datatype: Type[Type]) -> bool:
+        return datatype == str
+
+    def description(self) -> str:
+        return f"Validates that a string contains [{self.contains}] within it."
+
+    def validate(self, data: str) -> base.ValidationResult:
+        passes = all([c in data for c in self.contains])
+        return base.ValidationResult(
+            passes=passes,
+            message=(f"String did not contain {self.contains}" if not passes else "All good."),
+            diagnostics=(
+                {"contains": self.contains, "data": data if len(data) < 100 else data[:100]}
+                if not passes
+                else {}
+            ),
+        )
+
+    @classmethod
+    def arg(cls) -> str:
+        return "contains"
+
+
+class StrDoesNotContainValidator(base.BaseDefaultValidator):
+    def __init__(self, does_not_contain: Union[str, List[str]], importance: str):
+        super(StrDoesNotContainValidator, self).__init__(importance)
+        if isinstance(does_not_contain, str):
+            self.does_not_contain = [does_not_contain]
+        else:
+            self.does_not_contain = does_not_contain
+
+    @classmethod
+    def applies_to(cls, datatype: Type[Type]) -> bool:
+        return datatype == str
+
+    def description(self) -> str:
+        return f"Validates that a string does not contain [{self.does_not_contain}] within it."
+
+    def validate(self, data: str) -> base.ValidationResult:
+        passes = all([c not in data for c in self.does_not_contain])
+        return base.ValidationResult(
+            passes=passes,
+            message=(f"String did contain {self.does_not_contain}" if not passes else "All good."),
+            diagnostics=(
+                {
+                    "does_not_contain": self.does_not_contain,
+                    "data": data if len(data) < 100 else data[:100],
+                }
+                if not passes
+                else {}
+            ),
+        )
+
+    @classmethod
+    def arg(cls) -> str:
+        return "does_not_contain"
+
+
 AVAILABLE_DEFAULT_VALIDATORS = [
     AllowNaNsValidatorPandasSeries,
     DataInRangeValidatorPandasSeries,
@@ -419,6 +486,8 @@ def arg(cls) -> str:
     MaxStandardDevValidatorPandasSeries,
     MeanInRangeValidatorPandasSeries,
     AllowNoneValidator,
+    StrContainsValidator,
+    StrDoesNotContainValidator,
 ]
 
 
diff --git a/tests/test_default_data_quality.py b/tests/test_default_data_quality.py
index 54d72648a..77e76055b 100644
--- a/tests/test_default_data_quality.py
+++ b/tests/test_default_data_quality.py
@@ -222,6 +222,14 @@ def test_resolve_default_validators_error(output_type, kwargs, importance):
         (default_validators.AllowNoneValidator, False, 1, True),
         (default_validators.AllowNoneValidator, True, None, True),
         (default_validators.AllowNoneValidator, True, 1, True),
+        (default_validators.StrContainsValidator, "o b", "foo bar baz", True),
+        (default_validators.StrContainsValidator, "oof", "foo bar baz", False),
+        (default_validators.StrContainsValidator, ["o b", "baz"], "foo bar baz", True),
+        (default_validators.StrContainsValidator, ["oof", "bar"], "foo bar baz", False),
+        (default_validators.StrDoesNotContainValidator, "o b", "foo bar baz", False),
+        (default_validators.StrDoesNotContainValidator, "oof", "foo bar baz", True),
+        (default_validators.StrDoesNotContainValidator, ["o b", "boo"], "foo bar baz", False),
+        (default_validators.StrDoesNotContainValidator, ["oof", "boo"], "foo bar baz", True),
     ],
 )
 def test_default_data_validators(