Skip to content

Commit

Permalink
Merge pull request #37 from privateai/3.7update
Browse files Browse the repository at this point in the history
3.7update
  • Loading branch information
a-guiducci authored Feb 1, 2024
2 parents 62232fc + 3cdf897 commit cc1be03
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 90 deletions.
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@

### Fixed

## [3.7.1] - 2024-02-01

### Added
- Added "ALLOW_TEXT" as a valid type to filter objects

### Changed

### Fixed

## [3.7.0] - 2024-02-01

Expand Down
2 changes: 1 addition & 1 deletion src/privateai_client/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.7.0"
__version__ = "3.7.1"
108 changes: 28 additions & 80 deletions src/privateai_client/components/request_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,7 @@ def to_dict(self):
if self._issubclass(value):
dict_obj[name] = value.to_dict()
elif type(value) is list:
dict_obj[name] = [
row.to_dict() if self._issubclass(row) else row for row in value
]
dict_obj[name] = [row.to_dict() if self._issubclass(row) else row for row in value]
elif not key.startswith("__") and not callable(key):
dict_obj[name] = value
return dict_obj
Expand All @@ -34,14 +32,14 @@ class AudioOptions(BaseRequestObject):
default_bleep_start_padding: float = 0.5
default_bleep_end_padding: float = 0.5
default_bleep_frequency: Optional[int] = None
default_bleep_gain: Optional[int] = None
default_bleep_gain: Optional[int] = None

def __init__(
self,
bleep_start_padding: float = default_bleep_start_padding,
bleep_end_padding: float = default_bleep_end_padding,
bleep_frequency: Optional[int] = default_bleep_frequency,
bleep_gain: Optional[int] = default_bleep_gain
bleep_gain: Optional[int] = default_bleep_gain,
):
if self._bleep_start_padding_validator(bleep_start_padding):
self._bleep_start_padding = bleep_start_padding
Expand Down Expand Up @@ -90,34 +88,26 @@ def bleep_gain(self, var):

def _bleep_start_padding_validator(self, var):
if type(var) is not float:
raise ValueError(
f"AudioOptions.bleep_start_padding must be of type float, but got {type(var)}"
)
raise ValueError(f"AudioOptions.bleep_start_padding must be of type float, but got {type(var)}")
if var < 0:
raise ValueError("AudioOptions.bleep_start_padding must be positive")
return True

def _bleep_end_padding_validator(self, var):
if type(var) is not float:
raise ValueError(
f"AudioOptions.bleep_end_padding must be of type float, but got {type(var)}"
)
raise ValueError(f"AudioOptions.bleep_end_padding must be of type float, but got {type(var)}")
if var < 0:
raise ValueError("AudioOptions.bleep_end_padding must be positive")
return True

def _bleep_frequency_validator(self, var):
if type(var) is not int and var is not None:
raise ValueError(
f"AudioOptions.bleep_frequency must be of type int or None, but got {type(var)}"
)
raise ValueError(f"AudioOptions.bleep_frequency must be of type int or None, but got {type(var)}")
return True

def _bleep_gain_validator(self, var):
if type(var) is not int and var is not None:
raise ValueError(
f"AudioOptions.bleep_gain must be of type int or None, but got {type(var)}"
)
raise ValueError(f"AudioOptions.bleep_gain must be of type int or None, but got {type(var)}")
return True

@classmethod
Expand Down Expand Up @@ -157,9 +147,7 @@ def text(self, var):

def _processed_text_validator(self, var):
if type(var) is not str:
raise TypeError(
f"{var} is not valid. Entity.processed_text must be of type string"
)
raise TypeError(f"{var} is not valid. Entity.processed_text must be of type string")
return True

def _text_validator(self, var):
Expand All @@ -172,9 +160,7 @@ def fromdict(cls, values: dict):
try:
return cls._fromdict(values)
except TypeError:
raise TypeError(
"Entity can only accept the values 'processed_text' and 'text'"
)
raise TypeError("Entity can only accept the values 'processed_text' and 'text'")


class EntityTypeSelector(BaseRequestObject):
Expand Down Expand Up @@ -212,9 +198,7 @@ def fromdict(cls, values: dict):
try:
return cls._fromdict(values)
except TypeError:
raise TypeError(
"EntityTypeSelector can only accept the values 'type' and 'value'"
)
raise TypeError("EntityTypeSelector can only accept the values 'type' and 'value'")


class File(BaseRequestObject):
Expand Down Expand Up @@ -289,7 +273,7 @@ def fromdict(cls, values: dict):


class FilterSelector(BaseRequestObject):
valid_types = ["ALLOW", "BLOCK"]
valid_types = ["ALLOW", "BLOCK", "ALLOW_TEXT"]
default_threshold = 1

def __init__(
Expand Down Expand Up @@ -320,17 +304,13 @@ def pattern(self):
@property
def entity_type(self):
if self.type != "BLOCK":
raise AttributeError(
f"FilterSelector of type {self.type} does not contain entity_type"
)
raise AttributeError(f"FilterSelector of type {self.type} does not contain entity_type")
return self._entity_type

@property
def threshold(self):
if self.type != "BLOCK":
raise AttributeError(
f"FilterSelector of type {self.type} does not contain threshold"
)
raise AttributeError(f"FilterSelector of type {self.type} does not contain threshold")
return self._threshold

@type.setter
Expand Down Expand Up @@ -380,9 +360,7 @@ def fromdict(cls, values: dict):
try:
return cls._fromdict(values)
except TypeError:
raise TypeError(
"FilterSelector can only accept the values 'type' and 'pattern'"
)
raise TypeError("FilterSelector can only accept the values 'type' and 'pattern'")


class PDFOptions(BaseRequestObject):
Expand Down Expand Up @@ -443,9 +421,7 @@ class ProcessedMarkerText(BaseRequestObject):
]

def __init__(self, pattern: str = default_pattern):
for attribute in (
ProcessedMaskText.attributes + ProcessedSyntheticText.attributes
):
for attribute in ProcessedMaskText.attributes + ProcessedSyntheticText.attributes:
delattr(self, attribute) if hasattr(self, attribute) else False
self._type = "MARKER"
if self._pattern_validator(pattern):
Expand All @@ -472,9 +448,7 @@ class ProcessedMaskText(BaseRequestObject):
attributes = ["_mask_character"]

def __init__(self, mask_character: str = "#"):
for attribute in (
ProcessedMarkerText.attributes + ProcessedSyntheticText.attributes
):
for attribute in ProcessedMarkerText.attributes + ProcessedSyntheticText.attributes:
delattr(self, attribute) if hasattr(self, attribute) else False
if self._mask_character_validator(mask_character):
self._mask_character = mask_character
Expand All @@ -491,9 +465,7 @@ def mask_character(self, var):

def _mask_character_validator(self, var):
if len(var) != 1:
raise ValueError(
f"mask_character must have only one character. {var} has {len(var)} characters."
)
raise ValueError(f"mask_character must have only one character. {var} has {len(var)} characters.")
return True


Expand Down Expand Up @@ -559,9 +531,7 @@ def fromdict(cls, values: dict):
try:
return cls._fromdict(values)
except TypeError:
raise TypeError(
"ProcessedText can only accept the values 'type' and 'pattern'"
)
raise TypeError("ProcessedText can only accept the values 'type' and 'pattern'")

@property
def type(self):
Expand Down Expand Up @@ -676,24 +646,16 @@ def _accuracy_validator(self, var):

def _entity_types_validator(self, var):
if type(var) is not list:
raise TypeError(
f"{var} is not valid. EntityDetection.entity_types can only be a list"
)
raise TypeError(f"{var} is not valid. EntityDetection.entity_types can only be a list")
elif var and not all(isinstance(row, EntityTypeSelector) for row in var):
raise ValueError(
"EntityDetection.entity_types can only contain EntityTypeSelector objects"
)
raise ValueError("EntityDetection.entity_types can only contain EntityTypeSelector objects")
return True

def _filter_validator(self, var):
if type(var) is not list:
raise ValueError(
f"{var} is not valid. EntityDetection.filter can only be a list"
)
raise ValueError(f"{var} is not valid. EntityDetection.filter can only be a list")
elif var and not all(isinstance(x, FilterSelector) for x in var):
raise ValueError(
"EntityDetection.filter can only contain FilterSelector objects"
)
raise ValueError("EntityDetection.filter can only contain FilterSelector objects")
return True

def _return_entity_validator(self, var):
Expand All @@ -703,9 +665,7 @@ def _return_entity_validator(self, var):

def _enable_non_max_suppression_validator(self, var):
if type(var) is not bool:
raise ValueError(
"EntityDetection.enable_non_max_suppression must be of type bool"
)
raise ValueError("EntityDetection.enable_non_max_suppression must be of type bool")
return True

@classmethod
Expand All @@ -714,13 +674,9 @@ def fromdict(cls, values: dict):
initializer_dict = {}
for key, value in values.items():
if key == "entity_types":
initializer_dict[key] = [
EntityTypeSelector.fromdict(row) for row in value
]
initializer_dict[key] = [EntityTypeSelector.fromdict(row) for row in value]
elif key == "filter":
initializer_dict[key] = [
FilterSelector.fromdict(row) for row in value
]
initializer_dict[key] = [FilterSelector.fromdict(row) for row in value]
else:
initializer_dict[key] = value
return cls._fromdict(initializer_dict)
Expand Down Expand Up @@ -839,11 +795,7 @@ def fromdict(cls, values: dict):

class BleepRequest(BaseRequestObject):
def __init__(
self,
file: File,
timestamps: List,
bleep_frequency: Optional[int] = None,
bleep_gain: Optional[int] = None
self, file: File, timestamps: List, bleep_frequency: Optional[int] = None, bleep_gain: Optional[int] = None
):
self.file = file
self.timestamps = timestamps
Expand All @@ -858,9 +810,7 @@ def fromdict(cls, values: dict):
if key == "file":
initializer_dict[key] = File.fromdict(value)
elif key == "timestamps":
initializer_dict[key] = [
Timestamp.fromdict(entry) for entry in value
]
initializer_dict[key] = [Timestamp.fromdict(entry) for entry in value]
else:
initializer_dict[key] = value
return cls._fromdict(initializer_dict)
Expand Down Expand Up @@ -889,9 +839,7 @@ def fromdict(cls, values: dict):
initializer_dict = {}
for key, value in values.items():
if key == "entities":
initializer_dict[key] = [
Entity.fromdict(entity) for entity in values[key]
]
initializer_dict[key] = [Entity.fromdict(entity) for entity in values[key]]
else:
initializer_dict[key] = value
return cls._fromdict(initializer_dict)
Expand Down
15 changes: 12 additions & 3 deletions src/privateai_client/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ def test_entity_detection_with_block_filter():
assert resp.ok


def test_entity_detection_with_allow_text_filter():
client = _get_client()
filter = rq.filter_selector_obj(type="ALLOW_TEXT", pattern="[A-Za-z0-9]*")
req = rq.process_text_obj(text=["Hey there!"], entity_detection=rq.entity_detection_obj(filter=[filter]))
resp = client.process_text(req)
assert resp.ok


def test_full_entity_detection():
client = _get_client()
filter = rq.filter_selector_obj(type="ALLOW", pattern="[A-Za-z0-9]*")
Expand Down Expand Up @@ -152,13 +160,14 @@ def test_process_file_base64():
resp = client.process_files_base64(request_object=request_obj)
assert resp.ok


def test_process_audio_file_base64():
client = _get_client()

test_dir = "/".join(__file__.split("/")[:-1])
file_name = "test_audio.mp3"
filepath = os.path.join(f"{test_dir}", "test_files", file_name)
file_type= "audio/mp3"
file_type = "audio/mp3"

with open(filepath, "rb") as b64_file:
file_data = base64.b64encode(b64_file.read())
Expand All @@ -170,13 +179,14 @@ def test_process_audio_file_base64():
resp = client.process_files_base64(request_object=request_obj)
assert resp.ok


def test_bleep():
client = _get_client()

test_dir = "/".join(__file__.split("/")[:-1])
file_name = "test_audio.mp3"
filepath = os.path.join(f"{test_dir}", "test_files", file_name)
file_type= "audio/mp3"
file_type = "audio/mp3"

with open(filepath, "rb") as b64_file:
file_data = base64.b64encode(b64_file.read())
Expand All @@ -188,4 +198,3 @@ def test_bleep():
request_obj = rq.bleep_obj(file=file_obj, timestamps=[timestamp], bleep_frequency=500, bleep_gain=-30)
resp = client.bleep(request_object=request_obj)
assert resp.ok

Loading

0 comments on commit cc1be03

Please sign in to comment.