Skip to content

Commit

Permalink
Adds simple string "contains" and "does not contain" validators
Browse files Browse the repository at this point in the history
This will enable someone to do:

```
@check_output(contains=["this is a sentence"], does_not_contain=["foo", "bar"])
def llm_call(...) -> str:
   return response
```
  • Loading branch information
skrawcz authored and elijahbenizzy committed Feb 6, 2024
1 parent 0994226 commit e570014
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 0 deletions.
69 changes: 69 additions & 0 deletions hamilton/data_quality/default_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,73 @@ def arg(cls) -> str:
return "allow_none"


class StrContainsValidator(base.BaseDefaultValidator):
    """Default validator checking that a string contains every required substring.

    Selected through the ``contains`` keyword, e.g.
    ``@check_output(contains=["this is a sentence"])``.
    """

    def __init__(self, contains: Union[str, List[str]], importance: str):
        """Store the substrings that validated data must contain.

        :param contains: A single substring, or a collection of substrings,
            all of which must be present for validation to pass.
        :param importance: Forwarded unchanged to the base validator.
        """
        super(StrContainsValidator, self).__init__(importance)
        if isinstance(contains, str):
            # A bare string is ONE required substring, not an iterable of characters.
            self.contains = [contains]
        else:
            # Snapshot into a list: a one-shot iterable would otherwise be exhausted
            # by the first validate() call, and later mutation by the caller would
            # silently change what this validator checks.
            self.contains = list(contains)

    @classmethod
    def applies_to(cls, datatype: Type[Type]) -> bool:
        """Return True only when ``datatype`` is exactly ``str``."""
        return datatype == str

    def description(self) -> str:
        """Return a human-readable summary of what this validator checks."""
        return f"Validates that a string contains [{self.contains}] within it."

    def validate(self, data: str) -> base.ValidationResult:
        """Check that every required substring occurs in ``data``.

        :param data: The string to inspect.
        :return: A ``ValidationResult`` whose ``passes`` is True when nothing is
            missing. On failure the diagnostics carry the required substrings,
            the subset that was actually missing, and at most the first 100
            characters of ``data``.
        """
        missing = [required for required in self.contains if required not in data]
        if not missing:
            return base.ValidationResult(passes=True, message="All good.", diagnostics={})
        return base.ValidationResult(
            passes=False,
            message=f"String did not contain {self.contains}",
            diagnostics={
                "contains": self.contains,
                # Pinpoints the offending entries when several substrings are required.
                "missing": missing,
                # Slicing clamps to the string length, so no explicit length check is needed.
                "data": data[:100],
            },
        )

    @classmethod
    def arg(cls) -> str:
        """Return the keyword argument name that selects this validator."""
        return "contains"


class StrDoesNotContainValidator(base.BaseDefaultValidator):
    """Default validator checking that a string contains none of the forbidden substrings.

    Selected through the ``does_not_contain`` keyword, e.g.
    ``@check_output(does_not_contain=["foo", "bar"])``.
    """

    def __init__(self, does_not_contain: Union[str, List[str]], importance: str):
        """Store the substrings that validated data must not contain.

        :param does_not_contain: A single substring, or a collection of
            substrings, none of which may be present for validation to pass.
        :param importance: Forwarded unchanged to the base validator.
        """
        super(StrDoesNotContainValidator, self).__init__(importance)
        if isinstance(does_not_contain, str):
            # A bare string is ONE forbidden substring, not an iterable of characters.
            self.does_not_contain = [does_not_contain]
        else:
            # Snapshot into a list: a one-shot iterable would otherwise be exhausted
            # by the first validate() call, and later mutation by the caller would
            # silently change what this validator checks.
            self.does_not_contain = list(does_not_contain)

    @classmethod
    def applies_to(cls, datatype: Type[Type]) -> bool:
        """Return True only when ``datatype`` is exactly ``str``."""
        return datatype == str

    def description(self) -> str:
        """Return a human-readable summary of what this validator checks."""
        return f"Validates that a string does not contain [{self.does_not_contain}] within it."

    def validate(self, data: str) -> base.ValidationResult:
        """Check that no forbidden substring occurs in ``data``.

        :param data: The string to inspect.
        :return: A ``ValidationResult`` whose ``passes`` is True when no
            forbidden substring is present. On failure the diagnostics carry
            the forbidden substrings, the subset that was actually found, and
            at most the first 100 characters of ``data``.
        """
        found = [forbidden for forbidden in self.does_not_contain if forbidden in data]
        if not found:
            return base.ValidationResult(passes=True, message="All good.", diagnostics={})
        return base.ValidationResult(
            passes=False,
            message=f"String did contain {self.does_not_contain}",
            diagnostics={
                "does_not_contain": self.does_not_contain,
                # Pinpoints the offending entries when several substrings are forbidden.
                "found": found,
                # Slicing clamps to the string length, so no explicit length check is needed.
                "data": data[:100],
            },
        )

    @classmethod
    def arg(cls) -> str:
        """Return the keyword argument name that selects this validator."""
        return "does_not_contain"


AVAILABLE_DEFAULT_VALIDATORS = [
AllowNaNsValidatorPandasSeries,
DataInRangeValidatorPandasSeries,
Expand All @@ -419,6 +486,8 @@ def arg(cls) -> str:
MaxStandardDevValidatorPandasSeries,
MeanInRangeValidatorPandasSeries,
AllowNoneValidator,
StrContainsValidator,
StrDoesNotContainValidator,
]


Expand Down
8 changes: 8 additions & 0 deletions tests/test_default_data_quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,14 @@ def test_resolve_default_validators_error(output_type, kwargs, importance):
(default_validators.AllowNoneValidator, False, 1, True),
(default_validators.AllowNoneValidator, True, None, True),
(default_validators.AllowNoneValidator, True, 1, True),
(default_validators.StrContainsValidator, "o b", "foo bar baz", True),
(default_validators.StrContainsValidator, "oof", "foo bar baz", False),
(default_validators.StrContainsValidator, ["o b", "baz"], "foo bar baz", True),
(default_validators.StrContainsValidator, ["oof", "bar"], "foo bar baz", False),
(default_validators.StrDoesNotContainValidator, "o b", "foo bar baz", False),
(default_validators.StrDoesNotContainValidator, "oof", "foo bar baz", True),
(default_validators.StrDoesNotContainValidator, ["o b", "boo"], "foo bar baz", False),
(default_validators.StrDoesNotContainValidator, ["oof", "boo"], "foo bar baz", True),
],
)
def test_default_data_validators(
Expand Down

0 comments on commit e570014

Please sign in to comment.