From 8ba1a600a6b2c4b8d20e9d9ed6d4a771afbfff61 Mon Sep 17 00:00:00 2001 From: Michael Chouinard Date: Wed, 8 Nov 2023 12:31:40 -0500 Subject: [PATCH 1/6] Restructure error responses for Marshmallow --- api/src/api/schemas/extension/__init__.py | 6 + .../api/schemas/extension/field_validators.py | 154 ++++++++ api/src/api/schemas/extension/schema.py | 50 +++ .../api/schemas/extension/schema_common.py | 9 + .../api/schemas/extension/schema_fields.py | 249 +++++++++++++ api/src/validation/__init__.py | 0 api/src/validation/validation_constants.py | 32 ++ api/tests/api/schemas/__init__.py | 0 api/tests/api/schemas/extension/__init__.py | 0 .../schemas/extension/test_schema_fields.py | 87 +++++ .../api/schemas/schema_validation_utils.py | 329 ++++++++++++++++++ 11 files changed, 916 insertions(+) create mode 100644 api/src/api/schemas/extension/__init__.py create mode 100644 api/src/api/schemas/extension/field_validators.py create mode 100644 api/src/api/schemas/extension/schema.py create mode 100644 api/src/api/schemas/extension/schema_common.py create mode 100644 api/src/api/schemas/extension/schema_fields.py create mode 100644 api/src/validation/__init__.py create mode 100644 api/src/validation/validation_constants.py create mode 100644 api/tests/api/schemas/__init__.py create mode 100644 api/tests/api/schemas/extension/__init__.py create mode 100644 api/tests/api/schemas/extension/test_schema_fields.py create mode 100644 api/tests/api/schemas/schema_validation_utils.py diff --git a/api/src/api/schemas/extension/__init__.py b/api/src/api/schemas/extension/__init__.py new file mode 100644 index 000000000..be202f488 --- /dev/null +++ b/api/src/api/schemas/extension/__init__.py @@ -0,0 +1,6 @@ +from . import field_validators as validators +from . 
import schema_fields as fields +from .schema import Schema +from .schema_common import MarshmallowErrorContainer + +__all__ = ["fields", "validators", "Schema", "MarshmallowErrorContainer"] diff --git a/api/src/api/schemas/extension/field_validators.py b/api/src/api/schemas/extension/field_validators.py new file mode 100644 index 000000000..98be35e62 --- /dev/null +++ b/api/src/api/schemas/extension/field_validators.py @@ -0,0 +1,154 @@ +import copy +import typing + +from apiflask import validators +from marshmallow import ValidationError + +from src.api.schemas.extension.schema_common import MarshmallowErrorContainer +from src.validation.validation_constants import ValidationErrorType + + +class Regexp(validators.Regexp): + REGEX_ERROR = MarshmallowErrorContainer( + ValidationErrorType.FORMAT, "String does not match expected pattern." + ) + + @typing.overload + def __call__(self, value: str) -> str: + ... + + @typing.overload + def __call__(self, value: bytes) -> bytes: + ... + + def __call__(self, value: str | bytes) -> str | bytes: + if self.regex.match(value) is None: # type: ignore + raise ValidationError([self.REGEX_ERROR]) + + return value + + +class Length(validators.Length): + """Validator which succeeds if the value passed to it has a + length between a minimum and maximum. Uses len(), so it + can work for strings, lists, or anything with length. + + :param min: The minimum length. If not provided, minimum length + will not be checked. + :param max: The maximum length. If not provided, maximum length + will not be checked. + :param equal: The exact length. If provided, maximum and minimum + length will not be checked. + :param error: Error message to raise in case of a validation error. + Can be interpolated with `{input}`, `{min}` and `{max}`. + """ + + error_mapping: dict[str, MarshmallowErrorContainer] = { + "message_min": MarshmallowErrorContainer( + ValidationErrorType.MIN_LENGTH, "Shorter than minimum length {min}." 
+ ), + "message_max": MarshmallowErrorContainer( + ValidationErrorType.MAX_LENGTH, "Longer than maximum length {max}." + ), + "message_all": MarshmallowErrorContainer( + ValidationErrorType.MIN_OR_MAX_LENGTH, "Length must be between {min} and {max}." + ), + "message_equal": MarshmallowErrorContainer( + ValidationErrorType.EQUALS, "Length must be {equal}." + ), + } + + def _make_error(self, key: str) -> ValidationError: + try: + # Make a copy of the error mapping so we aren't modifying + # the class-level configurations above when we do formatting + error_container = copy.copy(self.error_mapping[key]) + except KeyError as error: + class_name = self.__class__.__name__ + message = ( + "ValidationError raised by `{class_name}`, but error key `{key}` does " + "not exist in the `error_messages` dictionary." + ).format(class_name=class_name, key=key) + raise AssertionError(message) from error + + error_container.message = error_container.message.format( + min=self.min, max=self.max, equal=self.equal + ) + + return ValidationError([error_container]) + + def __call__(self, value: typing.Sized) -> typing.Sized: + length = len(value) + + if self.equal is not None: + if length != self.equal: + raise self._make_error("message_equal") + return value + + if self.min is not None and length < self.min: + key = "message_min" if self.max is None else "message_all" + raise self._make_error(key) + + if self.max is not None and length > self.max: + key = "message_max" if self.min is None else "message_all" + raise self._make_error(key) + + return value + + +class Email(validators.Email): + EMAIL_ERROR = MarshmallowErrorContainer( + ValidationErrorType.FORMAT, "Not a valid email address." 
+ ) + + def __call__(self, value: str) -> str: + try: + return super().__call__(value) + except ValidationError: + # Fix the validation error to have our format + raise ValidationError([self.EMAIL_ERROR]) from None + + +class OneOf(validators.OneOf): + """ + Validator which succeeds if ``value`` is a member of ``choices``. + + Use this when you want to limit the choices, but don't need the value to be an enum + """ + + CONTAINS_ONLY_ERROR = MarshmallowErrorContainer( + ValidationErrorType.INVALID_CHOICE, "Value must be one of: {choices_text}" + ) + + def __call__(self, value: typing.Any) -> typing.Any: + if value not in self.choices: + error_container = copy.copy(self.CONTAINS_ONLY_ERROR) + error_container.message = error_container.message.format(choices_text=self.choices_text) + raise ValidationError([error_container]) + + return value + + +_T = typing.TypeVar("_T") + + +class Range(validators.Range): + def _format_error(self, value: _T, message: str) -> list[MarshmallowErrorContainer]: # type: ignore + # The method this overrides returns a string, but we'll modify it to return one of + # our error containers instead which works, but MyPy doesn't like. 
+ + is_min = False + is_max = False + if self.min is not None or self.max_inclusive is not None: + is_min = True + if self.max is not None or self.max_inclusive is not None: + is_max = True + + if is_min and is_max: + error_type = ValidationErrorType.MIN_OR_MAX_VALUE + elif is_min: + error_type = ValidationErrorType.MIN_VALUE + else: # must be max, init requires you set something + error_type = ValidationErrorType.MAX_VALUE + + return [MarshmallowErrorContainer(error_type, super()._format_error(value, message))] diff --git a/api/src/api/schemas/extension/schema.py b/api/src/api/schemas/extension/schema.py new file mode 100644 index 000000000..ad00b4851 --- /dev/null +++ b/api/src/api/schemas/extension/schema.py @@ -0,0 +1,50 @@ +from typing import Any, cast + +import apiflask +from marshmallow import EXCLUDE + +from src.api.schemas.extension.schema_common import MarshmallowErrorContainer +from src.validation.validation_constants import ValidationErrorType + + +class Schema(apiflask.Schema): + # There's no clean way to override the error messages at the schema-level + # as they get stored directly into the internal error store of the Schema object + # + # This approach is a little hacky, but we just change the default error messages to + # return the error container objects directly to work around that + _default_error_messages = cast( + dict[str, str], + { + "type": MarshmallowErrorContainer( + key=ValidationErrorType.INVALID, message="Invalid input type." + ), + "unknown": MarshmallowErrorContainer( + key=ValidationErrorType.UNKNOWN, message="Unknown field." + ), + }, + ) + + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + + # In order for the OpenAPI docs to display correctly + # we need to set sub-schemas as partial=True, as the + # apispec library doesn't handle recursively passing that down + # like it should through nested/list objects. 
+ if self.partial is True: + for field in self.declared_fields.values(): + # If the field has nested, then it's a + # Nested field object + if hasattr(field, "nested"): + field.nested.partial = True + + # If the field has inner, then it's a list + # which has a nested schema within it + if hasattr(field, "inner"): + if hasattr(field.inner, "nested"): + field.inner.nested.partial = True + + class Meta: + # Ignore any extra fields + unknown = EXCLUDE diff --git a/api/src/api/schemas/extension/schema_common.py b/api/src/api/schemas/extension/schema_common.py new file mode 100644 index 000000000..92ad58c1f --- /dev/null +++ b/api/src/api/schemas/extension/schema_common.py @@ -0,0 +1,9 @@ +import dataclasses + +from src.validation.validation_constants import ValidationErrorType + + +@dataclasses.dataclass +class MarshmallowErrorContainer: + key: ValidationErrorType + message: str diff --git a/api/src/api/schemas/extension/schema_fields.py b/api/src/api/schemas/extension/schema_fields.py new file mode 100644 index 000000000..aa3e2491f --- /dev/null +++ b/api/src/api/schemas/extension/schema_fields.py @@ -0,0 +1,249 @@ +import copy +import enum +import typing +import uuid + +from apiflask import fields as original_fields +from marshmallow import ValidationError + +from src.api.schemas.extension.field_validators import Range +from src.api.schemas.extension.schema_common import MarshmallowErrorContainer +from src.validation.validation_constants import ValidationErrorType + + +class MixinField(original_fields.Field): + """ + Field mixin class to override the make_error method on each of + our field classes defined below. + + Note that in Python when a class inherits from multiple classes, + the left-most one takes precedence, so if any subclass of Field + were to modify the make_error method, that should take precedence + over this one. 
+ + As make_error is only defined once in the Field class, this is fine + """ + + # Any derived class can specify an error_mapping object + # and it will be used / override the defaults here + error_mapping: dict[str, MarshmallowErrorContainer] = { + "required": MarshmallowErrorContainer( + ValidationErrorType.REQUIRED, "Missing data for required field." + ), + "invalid": MarshmallowErrorContainer(ValidationErrorType.INVALID, "Invalid value."), + "null": MarshmallowErrorContainer(ValidationErrorType.NOT_NULL, "Field may not be null."), + # not sure when this one gets hit, a failed validator uses the validator message + "validator_failed": MarshmallowErrorContainer( + ValidationErrorType.INVALID, "Invalid value." + ), + } + + def __init__(self, allow_none: bool = True, **kwargs: typing.Any) -> None: + super().__init__(allow_none=allow_none, **kwargs) + + # The actual error mapping used for a specific instance + self._error_mapping: dict[str, MarshmallowErrorContainer] = {} + + # This iterates over all classes and updates the error + # mapping with the most-specific class values overriding + # the most generic. + for cls in reversed(self.__class__.__mro__): + # Copy the error mapping values so any alterations don't + # affect other class objects + configured_error_mapping = getattr(cls, "error_mapping", {}) + for k, v in configured_error_mapping.items(): + self._error_mapping[k] = copy.copy(v) + + def make_error(self, key: str, **kwargs: typing.Any) -> ValidationError: + """Helper method to make a `ValidationError` with an error message + from ``self.error_mapping``. + """ + try: + error_container = self._error_mapping[key] + except KeyError as error: + class_name = self.__class__.__name__ + message = ( + "ValidationError raised by `{class_name}`, but error key `{key}` does " + "not exist in the `error_mapping` dictionary." 
+ ).format(class_name=class_name, key=key) + raise AssertionError(message) from error + + if kwargs: + error_container.message = error_container.message.format(**kwargs) + + return ValidationError([error_container]) + + +class String(original_fields.String, MixinField): + error_mapping: dict[str, MarshmallowErrorContainer] = { + "invalid": MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid string."), + "invalid_utf8": MarshmallowErrorContainer( + ValidationErrorType.INVALID, "Not a valid utf-8 string." + ), + } + + +class Integer(original_fields.Integer, MixinField): + error_mapping: dict[str, MarshmallowErrorContainer] = { + "invalid": MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid integer."), + } + + def __init__(self, restrict_to_32bit_int: bool = True, **kwargs: typing.Any): + # By default, we'll restrict all integer values to 32-bits so that they can be stored in + # Postgres' integer column. If you wish to process a larger value, simply set this to false or specify + # your own min/max Range. 
+ if restrict_to_32bit_int: + validators = kwargs.get("validate", []) + + # If a different range is specified, skip adding this one to avoid duplicate error messages + has_range_validator = False + for validator in validators: + if isinstance(validator, Range): + has_range_validator = True + break + + if not has_range_validator: + validators.append(Range(-2147483648, 2147483647)) + kwargs["validate"] = validators + + super().__init__(**kwargs) + + +class Boolean(original_fields.Boolean, MixinField): + error_mapping: dict[str, MarshmallowErrorContainer] = { + "invalid": MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid boolean."), + } + + +class Decimal(original_fields.Decimal, MixinField): + error_mapping: dict[str, MarshmallowErrorContainer] = { + "invalid": MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid decimal."), + "special": MarshmallowErrorContainer( + ValidationErrorType.SPECIAL_NUMERIC, + "Special numeric values (nan or infinity) are not permitted.", + ), + } + + +class UUID(original_fields.UUID, MixinField): + error_mapping: dict[str, MarshmallowErrorContainer] = { + "invalid": MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid UUID."), + "invalid_uuid": MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid UUID."), + } + + def __init__(self, **kwargs: typing.Any): + super().__init__(**kwargs) + self.metadata["example"] = uuid.uuid4() + + +class Date(original_fields.Date, MixinField): + error_mapping: dict[str, MarshmallowErrorContainer] = { + "invalid": MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid date."), + "format": MarshmallowErrorContainer( + ValidationErrorType.FORMAT, "'{input}' cannot be formatted as a date." 
+ ), + } + + +class DateTime(original_fields.DateTime, MixinField): + error_mapping: dict[str, MarshmallowErrorContainer] = { + "invalid": MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid datetime."), + "invalid_awareness": MarshmallowErrorContainer( + ValidationErrorType.INVALID, "Not a valid datetime." + ), + "format": MarshmallowErrorContainer( + ValidationErrorType.FORMAT, "'{input}' cannot be formatted as a datetime." + ), + } + + +class List(original_fields.List, MixinField): + error_mapping: dict[str, MarshmallowErrorContainer] = { + "invalid": MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid list."), + } + + +class Nested(original_fields.Nested, MixinField): + error_mapping: dict[str, MarshmallowErrorContainer] = { + "type": MarshmallowErrorContainer(ValidationErrorType.INVALID, "Invalid type."), + } + + def __init__(self, nested: typing.Any, **kwargs: typing.Any): + super().__init__(nested=nested, **kwargs) + # We set this to object so that if it's nullable, it'll + # get generated in the OpenAPI to allow nullable + type_values = ["object"] + if self.allow_none: + type_values.append("null") + self.metadata["type"] = type_values + + +class Raw(original_fields.Raw, MixinField): + # No error mapping changed from the default + pass + + +class Enum(MixinField): + """ + Custom field class for handling unioning together multiple Python enums into + a single enum field in the generated openapi schema. + + For example, if you have an enum with values x, y, z, and another enum with values a, b, c + using this class all 6 of these values would be possible, and when the value + is deserialized, we would properly convert it to the proper enum object + """ + + error_mapping: dict[str, MarshmallowErrorContainer] = { + "unknown": MarshmallowErrorContainer( + ValidationErrorType.INVALID_CHOICE, "Must be one of: {choices}." 
+ ), + } + + def __init__(self, *enums: typing.Type[enum.Enum], **kwargs: typing.Any) -> None: + super().__init__(**kwargs) + + self.enums = enums + self.field = original_fields.Field() + + self.enum_mapping = {} + + possible_choices = [] + for e in self.enums: + for raw_enum_value in e: + enum_value = str(self.field._serialize(raw_enum_value.value, None, None)) + possible_choices.append(enum_value) + self.enum_mapping[enum_value] = e + + self.choices_text = ", ".join(possible_choices) + # Set the enum metadata + self.metadata["enum"] = possible_choices + # Set the type so Swagger will know it's an enum-string + if self.metadata.get("type") is None: + type_values = ["string"] + if self.allow_none: + type_values.append("null") + self.metadata["type"] = type_values + + def _serialize( + self, value: typing.Any, attr: str | None, obj: typing.Any, **kwargs: typing.Any + ) -> typing.Any: + if value is None: + return None + + val = value.value + return self.field._serialize(val, attr, obj, **kwargs) + + def _deserialize( + self, + value: typing.Any, + attr: str | None, + data: typing.Mapping[str, typing.Any] | None, + **kwargs: typing.Any, + ) -> typing.Any: + val = self.field._deserialize(value, attr, data, **kwargs) + + enum_type = self.enum_mapping.get(val) + if not enum_type: + raise self.make_error("unknown", choices=self.choices_text) + + return enum_type(val) diff --git a/api/src/validation/__init__.py b/api/src/validation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/src/validation/validation_constants.py b/api/src/validation/validation_constants.py new file mode 100644 index 000000000..6fffcba4a --- /dev/null +++ b/api/src/validation/validation_constants.py @@ -0,0 +1,32 @@ +from enum import StrEnum + + +class ValidationErrorType(StrEnum): + """ + Error type codes which clients + need to be aware of in order + to display proper messaging to users. 
+ + *** WARNING *** + Do not adjust these values unless you + are certain that any and all users + are aware of the change, safer to add a new one + """ + + REQUIRED = "required" + NOT_NULL = "not_null" + UNKNOWN = "unknown" + INVALID = "invalid" + + FORMAT = "format" + INVALID_CHOICE = "invalid_choice" + SPECIAL_NUMERIC = "special_numeric" + + MIN_LENGTH = "min_length" + MAX_LENGTH = "max_length" + MIN_OR_MAX_LENGTH = "min_or_max_length" + EQUALS = "equals" + + MIN_VALUE = "min_value" + MAX_VALUE = "max_value" + MIN_OR_MAX_VALUE = "min_or_max_value" diff --git a/api/tests/api/schemas/__init__.py b/api/tests/api/schemas/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/api/schemas/extension/__init__.py b/api/tests/api/schemas/extension/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/api/tests/api/schemas/extension/test_schema_fields.py b/api/tests/api/schemas/extension/test_schema_fields.py new file mode 100644 index 000000000..94cbc3a51 --- /dev/null +++ b/api/tests/api/schemas/extension/test_schema_fields.py @@ -0,0 +1,87 @@ +import inspect + +import pytest +from marshmallow import ValidationError + +from src.api.schemas.extension import fields +from tests.api.schemas.schema_validation_utils import ( + DummySchema, + EnumA, + EnumB, + FieldTestSchema, + get_expected_validation_errors, + get_invalid_field_test_schema_req, + get_valid_field_test_schema_req, + validate_errors, +) + + +def test_enum_field(): + schema = DummySchema() + + both_ab_field = schema.declared_fields["both_ab"] + + # Make sure the multi enum can deserialize to both enums and reserialize to a string + for e in EnumA: + deserialized_value = both_ab_field._deserialize(str(e), None, None) + assert deserialized_value == e + assert isinstance(deserialized_value, EnumA) + + serialized_value = both_ab_field._serialize(e, None, None) + assert isinstance(serialized_value, str) + for e in EnumB: + deserialized_value = 
both_ab_field._deserialize(str(e), None, None) + assert deserialized_value == e + assert isinstance(deserialized_value, EnumB) + + serialized_value = both_ab_field._serialize(e, None, None) + assert isinstance(serialized_value, str) + + with pytest.raises( + ValidationError, match="Must be one of: value1, value2, value3, value4, value5, value6." + ): + both_ab_field._deserialize("not_a_value", None, None) + + +@pytest.mark.parametrize( + "payload,expected_errors", + [(get_invalid_field_test_schema_req(), get_expected_validation_errors())], +) +def test_fields(payload, expected_errors): + errors = FieldTestSchema().validate(payload) + validate_errors(errors, expected_errors) + + +def test_fields_ignore_unknowns(): + unknown_key = "UNKNOWN" + payload = {**get_valid_field_test_schema_req(), unknown_key: "EXCLUDED"} + result = FieldTestSchema().load(payload) + assert unknown_key not in result + + +def test_fields_configured_properly(): + """ + This is a sanity-test to verify we have properly + overridden and defined all the default error codes + that Marshmallow uses. 
+ + If you see this test failing after updating our + dependency on Marshmallow, likely just need to add + a configuration to the relevant class' "error_mapping" object + """ + relevant_classes = [] + for _, obj in inspect.getmembers(fields): + if inspect.isclass(obj) and issubclass(obj, fields.MixinField): + relevant_classes.append(obj) + + for relevant_class in relevant_classes: + if relevant_class == fields.Enum: + # We don't derive from the original and made a custom enum field + # so the default error messages aren't relevant + assert relevant_class.error_mapping.keys() == {"unknown"} + continue + + # We want to make sure all keys are configured, but we also may have more + required_error_message_keys = relevant_class.default_error_messages.keys() + configured_error_message_keys = relevant_class.error_mapping.keys() + assert configured_error_message_keys >= required_error_message_keys diff --git a/api/tests/api/schemas/schema_validation_utils.py b/api/tests/api/schemas/schema_validation_utils.py new file mode 100644 index 000000000..02fd7038a --- /dev/null +++ b/api/tests/api/schemas/schema_validation_utils.py @@ -0,0 +1,329 @@ +from enum import Enum, StrEnum +from random import choice +from string import ascii_uppercase +from typing import Type + +from src.api.schemas.extension import MarshmallowErrorContainer, Schema, fields, validators +from src.validation.validation_constants import ValidationErrorType + +############################# +# Validation Error Messages +############################# +MISSING_DATA = MarshmallowErrorContainer( + ValidationErrorType.REQUIRED, "Missing data for required field." +) +INVALID_INTEGER = MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid integer.") +INVALID_STRING = MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid string.") +INVALID_STRING_PATTERN = MarshmallowErrorContainer( + ValidationErrorType.FORMAT, "String does not match expected pattern." 
+) +INVALID_DATE = MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid date.") +INVALID_DATETIME = MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid datetime.") +INVALID_BOOLEAN = MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid boolean.") +INVALID_SCHEMA_MSG = MarshmallowErrorContainer(ValidationErrorType.INVALID, "Invalid input type.") +INVALID_SCHEMA = {"_schema": [INVALID_SCHEMA_MSG]} +INVALID_LIST = MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid list.") +INVALID_UUID = MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid UUID.") +INVALID_DECIMAL = MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid decimal.") +INVALID_SPECIAL_DECIMAL = MarshmallowErrorContainer( + ValidationErrorType.SPECIAL_NUMERIC, + "Special numeric values (nan or infinity) are not permitted.", +) +INVALID_EMAIL = MarshmallowErrorContainer(ValidationErrorType.FORMAT, "Not a valid email address.") +UNKNOWN_FIELD = MarshmallowErrorContainer(ValidationErrorType.UNKNOWN, "Unknown field.") + + +######################## +# Validation Utilities +######################## +def get_random_string(length: int): + return "".join(choice(ascii_uppercase) for i in range(length)) + + +def get_enum_error_msg(*enums: Type[Enum]): + possible_values = [] + for enum in enums: + possible_values.extend([e.value for e in enum]) + + return MarshmallowErrorContainer( + ValidationErrorType.INVALID_CHOICE, f"Must be one of: {', '.join(possible_values)}." + ) + + +def get_one_of_error_msg(choices: list[str]): + choices_text = ", ".join([c for c in choices]) + + return MarshmallowErrorContainer( + ValidationErrorType.INVALID_CHOICE, f"Value must be one of: {choices_text}" + ) + + +def get_min_length_error_msg(length: int): + return MarshmallowErrorContainer( + ValidationErrorType.MIN_LENGTH, f"Shorter than minimum length {length}." 
+ ) + + +def get_max_length_error_msg(length: int): + return MarshmallowErrorContainer( + ValidationErrorType.MAX_LENGTH, f"Longer than maximum length {length}." + ) + + +def get_length_range_error_msg(min: int, max: int): + return MarshmallowErrorContainer( + ValidationErrorType.MIN_OR_MAX_LENGTH, f"Length must be between {min} and {max}." + ) + + +def get_length_equal_error_msg(equal: int): + return MarshmallowErrorContainer(ValidationErrorType.EQUALS, f"Length must be {equal}.") + + +def get_max_or_min_value_error_msg(min: int = -2147483648, max: int = 2147483647): + # defaults are the 32-bit integer min/max + return MarshmallowErrorContainer( + ValidationErrorType.MIN_OR_MAX_VALUE, + f"Must be greater than or equal to {min} and less than or equal to {max}.", + ) + + +def validate_errors(actual_errors, expected_errors): + assert len(actual_errors) == len( + expected_errors + ), f"Expected {len(expected_errors)}, but had {len(actual_errors)} errors" + for field_name in actual_errors: + assert field_name in expected_errors, f"{field_name} in errors but not expected" + assert ( + expected_errors[field_name] == actual_errors[field_name] + ), f"Actual error for {field_name}: {str(actual_errors[field_name])} but received {str(expected_errors[field_name])}" + + +######################## +# Schemas for testing +######################## + + +class EnumA(StrEnum): + VALUE1 = "value1" + VALUE2 = "value2" + VALUE3 = "value3" + + +class EnumB(StrEnum): + VALUE4 = "value4" + VALUE5 = "value5" + VALUE6 = "value6" + + +class DummySchema(Schema): + both_ab = fields.Enum(EnumA, EnumB) + + +class InnerTestSchema(Schema): + inner_str = fields.String() + inner_required_str = fields.String(required=True) + + +class FieldTestSchema(Schema): + field_str = fields.String() + field_str_required = fields.String(required=True) + field_str_min = fields.String(validate=[validators.Length(min=2)]) + field_str_max = fields.String(validate=[validators.Length(max=3)]) + field_str_min_and_max = 
fields.String(validate=[validators.Length(min=2, max=3)]) + field_str_equal = fields.String(validate=[validators.Length(equal=3)]) + field_str_regex = fields.String(validate=[validators.Regexp("^\\d{3}$")]) + field_str_email = fields.String(validate=[validators.Email()]) + + field_int = fields.Integer() + field_int_required = fields.Integer(required=True) + field_int_strict = fields.Integer(strict=True) + + field_bool = fields.Boolean() + field_bool_required = fields.Boolean(required=True) + + field_decimal = fields.Decimal() + field_decimal_required = fields.Decimal(required=True) + field_decimal_special = fields.Decimal(allow_nan=False) + + field_uuid = fields.UUID() + field_uuid_required = fields.UUID(required=True) + + field_date = fields.Date() + field_date_required = fields.Date(required=True) + field_date_format = fields.Date(format="iso8601") + + field_datetime = fields.DateTime() + field_datetime_required = fields.DateTime(required=True) + field_datetime_format = fields.DateTime(format="iso8601") + + field_list = fields.List(fields.Boolean()) + field_list_required = fields.List(fields.Integer(), required=True) + field_list_indexed = fields.List(fields.Integer()) + + field_nested = fields.Nested(InnerTestSchema()) + field_nested_invalid = fields.Nested(InnerTestSchema()) + field_nested_required = fields.Nested(InnerTestSchema(), required=True) + + field_list_nested = fields.List(fields.Nested(InnerTestSchema())) + field_list_nested_invalid = fields.List(fields.Nested(InnerTestSchema())) + field_list_nested_required = fields.List(fields.Nested(InnerTestSchema()), required=True) + + # There's no "invalid" raw field it doesn't serialize/deserialize + field_raw_required = fields.Raw(required=True) + + field_enum = fields.Enum(EnumA) + field_enum_invalid_choice = fields.Enum(EnumA) + field_enum_required = fields.Enum(EnumB, required=True) + + +######################## +# Requests for the above schema +######################## + + +def 
get_valid_field_test_schema_req(): + return { + "field_str": "text", + "field_str_required": "text", + "field_str_min": "abcd", + "field_str_max": "a", + "field_str_min_and_max": "ab", + "field_str_equal": "abc", + "field_str_regex": "123", + "field_str_email": "person@example.com", + "field_int": 1, + "field_int_required": 2, + "field_int_strict": 3, + "field_bool": True, + "field_bool_required": False, + "field_decimal": "2.5", + "field_decimal_required": "555", + "field_decimal_special": "4", + "field_uuid": "1234a5b6-7c8d-90ef-1ab2-c3d45678e9f0", + "field_uuid_required": "1234a5b6-7c8d-90ef-1ab2-c3d45678e9f0", + "field_date": "2000-01-01", + "field_date_required": "2010-02-02", + "field_date_format": "2020-03-03", + "field_datetime": "2000-01-01T00:01:01Z", + "field_datetime_required": "2010-02-02T00:02:02Z", + "field_datetime_format": "2020-03-03T00:03:03Z", + "field_list": [True], + "field_list_required": [], + "field_list_indexed": [1, 2, 3], + "field_nested": { + "inner_str": "text", + "inner_required_str": "text", + }, + "field_nested_invalid": { + "inner_str": "text", + "inner_required_str": "text", + }, + "field_nested_required": {"inner_str": "text", "inner_required_str": "present"}, + "field_list_nested": [ + {"inner_str": "text", "inner_required_str": "present"}, + {"inner_str": "text", "inner_required_str": "present"}, + ], + "field_list_nested_invalid": [], + "field_list_nested_required": [], + "field_raw_required": {}, + "field_enum": EnumA.VALUE1, + "field_enum_invalid_choice": EnumA.VALUE2, + "field_enum_required": EnumB.VALUE4, + } + + +def get_invalid_field_test_schema_req(): + return { + "field_str": 1234, + # field_str_required not present + "field_str_min": "a", + "field_str_max": "abcdef", + "field_str_min_and_max": "a", + "field_str_equal": "a", + "field_str_regex": "abc", + "field_str_email": "not an email", + "field_int": {}, + # field_int_required not present + "field_int_strict": "123", + "field_bool": 1234, + # field_bool_required not 
present + "field_decimal": "hello", + # field_decimal_required not present + "field_decimal_special": "NaN", + "field_uuid": "hello", + # field_uuid_required not present + "field_date": 1234, + # field_date_required not present + "field_date_format": "20220202020202", + "field_datetime": 1234, + # field_datetime_required not present + "field_datetime_format": "02022020 7-20PM PDT", + "field_list": "not_a_list", + # field_list_required not present + "field_list_indexed": ["text", 1, "text"], + "field_nested": { + "inner_str": 1234, + # inner_required_str not present + }, + "field_nested_invalid": 5678, + # field_nested_required not present + "field_list_nested": [ + {"inner_str": 5678, "inner_required_str": "present"}, + {"inner_str": "valid"}, # inner_required_str not present + 54321, + ], + "field_list_nested_invalid": 54321, + # field_list_nested_required not present + # field_raw_required not present + "field_enum": 12345, + "field_enum_invalid_choice": "notvalid", + } + + +def get_expected_validation_errors(): + # This is the expected output of the above + # get_invalid_field_test_schema_req function + return { + "field_str": [INVALID_STRING], + "field_str_required": [MISSING_DATA], + "field_str_min": [get_min_length_error_msg(2)], + "field_str_max": [get_max_length_error_msg(3)], + "field_str_min_and_max": [get_length_range_error_msg(2, 3)], + "field_str_equal": [get_length_equal_error_msg(3)], + "field_str_regex": [INVALID_STRING_PATTERN], + "field_str_email": [INVALID_EMAIL], + "field_int": [INVALID_INTEGER], + "field_int_required": [MISSING_DATA], + "field_int_strict": [INVALID_INTEGER], + "field_bool": [INVALID_BOOLEAN], + "field_bool_required": [MISSING_DATA], + "field_decimal": [INVALID_DECIMAL], + "field_decimal_required": [MISSING_DATA], + "field_decimal_special": [INVALID_SPECIAL_DECIMAL], + "field_uuid": [INVALID_UUID], + "field_uuid_required": [MISSING_DATA], + "field_date": [INVALID_DATE], + "field_date_required": [MISSING_DATA], + 
"field_date_format": [INVALID_DATE], + "field_datetime": [INVALID_DATETIME], + "field_datetime_required": [MISSING_DATA], + "field_datetime_format": [INVALID_DATETIME], + "field_list": [INVALID_LIST], + "field_list_required": [MISSING_DATA], + "field_list_indexed": {0: [INVALID_INTEGER], 2: [INVALID_INTEGER]}, + "field_nested": {"inner_str": [INVALID_STRING], "inner_required_str": [MISSING_DATA]}, + "field_nested_invalid": INVALID_SCHEMA, + "field_nested_required": [MISSING_DATA], + "field_list_nested": { + 0: {"inner_str": [INVALID_STRING]}, + 1: {"inner_required_str": [MISSING_DATA]}, + 2: INVALID_SCHEMA, + }, + "field_list_nested_invalid": [INVALID_LIST], + "field_list_nested_required": [MISSING_DATA], + "field_raw_required": [MISSING_DATA], + "field_enum": [get_enum_error_msg(EnumA)], + "field_enum_invalid_choice": [get_enum_error_msg(EnumA)], + "field_enum_required": [MISSING_DATA], + } From deb91c33c7e5811c29c4225e021096562fbea4c2 Mon Sep 17 00:00:00 2001 From: Michael Chouinard Date: Thu, 9 Nov 2023 13:04:28 -0500 Subject: [PATCH 2/6] Adding more utilities --- api/src/api/healthcheck.py | 6 +- .../api/opportunities/opportunity_schemas.py | 9 +- api/src/api/response.py | 91 +++++- api/src/api/schemas/response_schema.py | 18 +- api/src/app.py | 12 +- api/src/pagination/pagination_schema.py | 15 +- api/src/util/dict_util.py | 42 +++ api/tests/src/route/test_opportunity_route.py | 31 +- .../src/route/test_route_error_format.py | 283 ++++++++++++++++++ api/tests/src/util/test_dict_util.py | 53 ++++ 10 files changed, 504 insertions(+), 56 deletions(-) create mode 100644 api/src/util/dict_util.py create mode 100644 api/tests/src/route/test_route_error_format.py create mode 100644 api/tests/src/util/test_dict_util.py diff --git a/api/src/api/healthcheck.py b/api/src/api/healthcheck.py index 9883a14a8..767df63d5 100644 --- a/api/src/api/healthcheck.py +++ b/api/src/api/healthcheck.py @@ -8,13 +8,13 @@ import src.adapters.db.flask_db as flask_db from src.api import 
response -from src.api.schemas import request_schema +from src.api.schemas.extension import fields, Schema logger = logging.getLogger(__name__) -class HealthcheckSchema(request_schema.OrderedSchema): - message: str +class HealthcheckSchema(Schema): + message = fields.String() healthcheck_blueprint = APIBlueprint("healthcheck", __name__, tag="Health") diff --git a/api/src/api/opportunities/opportunity_schemas.py b/api/src/api/opportunities/opportunity_schemas.py index 0dc50c17c..e09ce683f 100644 --- a/api/src/api/opportunities/opportunity_schemas.py +++ b/api/src/api/opportunities/opportunity_schemas.py @@ -1,16 +1,15 @@ from typing import Any -from apiflask import fields +from src.api.schemas.extension import fields, Schema from marshmallow import post_load from src.api.feature_flags.feature_flag import FeatureFlag from src.api.feature_flags.feature_flag_config import FeatureFlagConfig, get_feature_flag_config -from src.api.schemas import request_schema from src.constants.lookup_constants import OpportunityCategory from src.pagination.pagination_schema import PaginationSchema, generate_sorting_schema -class OpportunitySchema(request_schema.OrderedSchema): +class OpportunitySchema(Schema): opportunity_id = fields.Integer( dump_only=True, metadata={"description": "The internal ID of the opportunity", "example": 12345}, @@ -46,7 +45,7 @@ class OpportunitySchema(request_schema.OrderedSchema): updated_at = fields.DateTime(dump_only=True) -class OpportunitySearchSchema(request_schema.OrderedSchema): +class OpportunitySearchSchema(Schema): opportunity_title = fields.String( metadata={ "description": "The title of the opportunity to search for", @@ -81,7 +80,7 @@ class OpportunitySearchSchema(request_schema.OrderedSchema): paging = fields.Nested(PaginationSchema(), required=True) -class OpportunitySearchHeaderSchema(request_schema.OrderedSchema): +class OpportunitySearchHeaderSchema(Schema): # Header field: X-FF-Enable-Opportunity-Log-Msg enable_opportunity_log_msg = 
fields.Boolean( data_key=FeatureFlag.ENABLE_OPPORTUNITY_LOG_MSG.get_header_name(), diff --git a/api/src/api/response.py b/api/src/api/response.py index 42ddfdbb4..1283aaccc 100644 --- a/api/src/api/response.py +++ b/api/src/api/response.py @@ -1,30 +1,36 @@ import dataclasses -from typing import Any, Optional +import apiflask +import logging +from typing import Any, Optional, Tuple +from src.api.schemas.extension import MarshmallowErrorContainer from src.pagination.pagination_models import PaginationInfo +from src.util.dict_util import flatten_dict +logger = logging.getLogger(__name__) @dataclasses.dataclass class ValidationErrorDetail: type: str message: str = "" - rule: Optional[str] = None field: Optional[str] = None value: Optional[str] = None # Do not store PII data here, as it gets logged in some cases -class ValidationException(Exception): - __slots__ = ["errors", "message", "data"] - +class ValidationException(apiflask.exceptions.HTTPError): def __init__( self, errors: list[ValidationErrorDetail], message: str = "Invalid request", - data: Optional[dict | list[dict]] = None, + detail: Any = None, ): + super().__init__( + status_code=422, + message=message, + detail=detail, + extra_data={"validation_issues": errors}, + ) self.errors = errors - self.message = message - self.data = data or {} @dataclasses.dataclass @@ -38,3 +44,72 @@ class ApiResponse: status_code: int = 200 pagination_info: PaginationInfo | None = None + + +def process_marshmallow_issues(marshmallow_issues: dict) -> list[ValidationErrorDetail]: + validation_errors: list[ValidationErrorDetail] = [] + + # Marshmallow structures its issues as + # {"path": {"to": {"value": ["issue1", "issue2"]}}} + # this flattens that to {"path.to.value": ["issue1", "issue2"]} + flattened_issues = flatten_dict(marshmallow_issues) + + # Take the flattened issues and create properly formatted + # error messages by translating the Marshmallow codes + for field, value in flattened_issues.items(): + if 
isinstance(value, list): + for item in value: + if not isinstance(item, MarshmallowErrorContainer): + msg = f"Unconfigured error in Marshmallow validation errors, expected MarshmallowErrorContainer, but got {item.__class__.__name__}" + logger.error(msg) + raise AssertionError(msg) + + # If marshmallow expects a field to be an object + # then it adds "._schema", we don't want that so trim it here + validation_errors.append( + ValidationErrorDetail( + field=field.removesuffix("._schema"), + message=item.message, + type=item.key, + ) + ) + else: + logger.error( + "Error format in json section was not formatted as expected, expected a list, got a %s", + type(value), + ) + + return validation_errors + +def restructure_error_response(error: apiflask.exceptions.HTTPError) -> Tuple[dict, int, Any]: + # Note that body needs to have the same schema as the ErrorResponseSchema we defined + # in app.api.route.schemas.response_schema.py + body = { + "message": error.message, + # we rename detail to data so success and error responses are consistent + "data": error.detail, + "status_code": error.status_code, + } + validation_errors: list[ValidationErrorDetail] = [] + + # Process Marshmallow issues and convert them to the proper format + # Marshmallow issues are put in the json error detail - the body of the request + if isinstance(error.detail, dict): + marshmallow_issues = error.detail.get("json") + if marshmallow_issues: + validation_errors.extend(process_marshmallow_issues(marshmallow_issues)) + + # We don't want to make the response confusing + # so we remove the now-duplicate error detail + del body["data"]["json"] + + # If we called raise_flask_error with a list of validation_issues + # then they get appended to the error response here + additional_validation_issues = error.extra_data.get("validation_issues") + if additional_validation_issues: + validation_errors.extend(additional_validation_issues) + + # Attach formatted errors to the error response + body["errors"] = 
validation_errors + + return body, error.status_code, error.headers diff --git a/api/src/api/schemas/response_schema.py b/api/src/api/schemas/response_schema.py index d5589d5f3..e11b79ca4 100644 --- a/api/src/api/schemas/response_schema.py +++ b/api/src/api/schemas/response_schema.py @@ -1,25 +1,27 @@ -from apiflask import fields -from src.api.schemas import request_schema from src.pagination.pagination_schema import PaginationInfoSchema +from src.api.schemas.extension import fields, Schema -class ValidationErrorSchema(request_schema.OrderedSchema): +class ValidationIssueSchema(Schema): type = fields.String(metadata={"description": "The type of error"}) message = fields.String(metadata={"description": "The message to return"}) - rule = fields.String(metadata={"description": "The rule that failed"}) field = fields.String(metadata={"description": "The field that failed"}) value = fields.String(metadata={"description": "The value that failed"}) -class ResponseSchema(request_schema.OrderedSchema): +class BaseResponseSchema(Schema): message = fields.String(metadata={"description": "The message to return"}) - data = fields.Field(metadata={"description": "The REST resource object"}, dump_default={}) + data = fields.MixinField(metadata={"description": "The REST resource object"}, dump_default={}) status_code = fields.Integer(metadata={"description": "The HTTP status code"}, dump_default=200) - warnings = fields.List(fields.Nested(ValidationErrorSchema()), dump_default=[]) - errors = fields.List(fields.Nested(ValidationErrorSchema()), dump_default=[]) pagination_info = fields.Nested( PaginationInfoSchema(), metadata={"description": "The pagination information for paginated endpoints"}, ) + +class ErrorResponseSchema(BaseResponseSchema): + errors = fields.List(fields.Nested(ValidationIssueSchema()), dump_default=[]) + +class ResponseSchema(BaseResponseSchema): + warnings = fields.List(fields.Nested(ValidationIssueSchema()), dump_default=[]) \ No newline at end of file diff 
--git a/api/src/app.py b/api/src/app.py index 3ded0d038..15a1ac8e5 100644 --- a/api/src/app.py +++ b/api/src/app.py @@ -1,8 +1,8 @@ import logging import os -from typing import Optional +from typing import Optional, Tuple, Any -from apiflask import APIFlask +from apiflask import APIFlask, exceptions from flask import g from werkzeug.exceptions import Unauthorized @@ -13,6 +13,7 @@ import src.logging.flask_logger as flask_logger from src.api.healthcheck import healthcheck_blueprint from src.api.opportunities import opportunity_blueprint +from src.api.response import restructure_error_response from src.api.schemas import response_schema from src.auth.api_key_auth import User, get_app_security_scheme @@ -50,6 +51,8 @@ def configure_app(app: APIFlask) -> None: # which adds additional details to the object. # https://apiflask.com/schema/#base-response-schema-customization app.config["BASE_RESPONSE_SCHEMA"] = response_schema.ResponseSchema + app.config["HTTP_ERROR_SCHEMA"] = response_schema.ErrorResponseSchema + app.config["VALIDATION_ERROR_SCHEMA"] = response_schema.ErrorResponseSchema # Set a few values for the Swagger endpoint app.config["OPENAPI_VERSION"] = "3.1.0" @@ -73,6 +76,11 @@ def configure_app(app: APIFlask) -> None: app.security_schemes = get_app_security_scheme() + @app.error_processor + def error_processor(error: exceptions.HTTPError) -> Tuple[dict, int, Any]: + return restructure_error_response(error) + + def register_blueprints(app: APIFlask) -> None: app.register_blueprint(healthcheck_blueprint) app.register_blueprint(opportunity_blueprint) diff --git a/api/src/pagination/pagination_schema.py b/api/src/pagination/pagination_schema.py index 13ff86f20..f33d6ce2c 100644 --- a/api/src/pagination/pagination_schema.py +++ b/api/src/pagination/pagination_schema.py @@ -1,15 +1,14 @@ from typing import Type -from apiflask import fields, validators +from src.api.schemas.extension import fields, validators, Schema -from src.api.schemas import request_schema from 
src.pagination.pagination_models import SortDirection -class PaginationSchema(request_schema.OrderedSchema): +class PaginationSchema(Schema): page_size = fields.Integer( required=True, - validate=validators.Range(min=1), + validate=[validators.Range(min=1)], metadata={"description": "The size of the page to fetch", "example": 25}, ) page_offset = fields.Integer( @@ -21,7 +20,7 @@ class PaginationSchema(request_schema.OrderedSchema): def generate_sorting_schema( cls_name: str, order_by_fields: list[str] | None = None -) -> Type[request_schema.OrderedSchema]: +) -> Type[Schema]: """ Generate a schema that describes the sorting for a pagination endpoint. @@ -30,7 +29,7 @@ def generate_sorting_schema( This is functionally equivalent to specifying your own class like so: - class MySortingSchema(request_schema.OrderedSchema): + class MySortingSchema(Schema): order_by = fields.String( validate=[validators.OneOf(["id","created_at","updated_at"])], required=True, @@ -59,10 +58,10 @@ class MySortingSchema(request_schema.OrderedSchema): metadata={"description": "Whether to sort the response ascending or descending"}, ), } - return request_schema.OrderedSchema.from_dict(ordering_schema_fields, name=cls_name) # type: ignore + return Schema.from_dict(ordering_schema_fields, name=cls_name) # type: ignore -class PaginationInfoSchema(request_schema.OrderedSchema): +class PaginationInfoSchema(Schema): # This is part of the response schema to provide all pagination information back to a user page_offset = fields.Integer( diff --git a/api/src/util/dict_util.py b/api/src/util/dict_util.py new file mode 100644 index 000000000..23bd225af --- /dev/null +++ b/api/src/util/dict_util.py @@ -0,0 +1,42 @@ +from typing import Any + + +def flatten_dict(in_dict: Any, separator: str = ".", prefix: str = "") -> dict: + """ + Takes a set of nested dictionaries and flattens it + + For example:: + + { + "a": { + "b": { + "c": "value_c" + }, + "d": "value_d" + }, + "e": "value_e" + } + + Would become:: 
+ + { + "a.b.c": "value_c", + "a.d": "value_d", + "e": "value_e" + } + """ + + if isinstance(in_dict, dict): + return_dict = {} + # Iterate over each item in the dictionary + for kk, vv in in_dict.items(): + # Flatten each item in the dictionary + for k, v in flatten_dict(vv, separator, str(kk)).items(): + # Update the path + new_key = prefix + separator + str(k) if prefix else str(k) + return_dict[new_key] = v + + return return_dict + + # value isn't a dictionary, so no more recursion + return {prefix: in_dict} diff --git a/api/tests/src/route/test_opportunity_route.py b/api/tests/src/route/test_opportunity_route.py index edeb3a69a..338530f8b 100644 --- a/api/tests/src/route/test_opportunity_route.py +++ b/api/tests/src/route/test_opportunity_route.py @@ -260,34 +260,20 @@ def test_opportunity_search_paging_and_sorting_200( [ ( {}, - { - "paging": ["Missing data for required field."], - "sorting": ["Missing data for required field."], - }, + [{'field': 'sorting', 'message': 'Missing data for required field.', 'type': 'required', 'value': None}, + {'field': 'paging', 'message': 'Missing data for required field.', 'type': 'required', 'value': None}] ), ( get_search_request(page_offset=-1, page_size=-1), - { - "paging": { - "page_offset": ["Must be greater than or equal to 1."], - "page_size": ["Must be greater than or equal to 1."], - } - }, + [{'field': 'paging.page_size', 'message': 'Must be greater than or equal to 1.', 'type': 'min_or_max_value', 'value': None}, {'field': 'paging.page_offset', 'message': 'Must be greater than or equal to 1.', 'type': 'min_or_max_value', 'value': None}], ), ( get_search_request(order_by="fake_field", sort_direction="up"), - { - "sorting": { - "order_by": [ - "Must be one of: opportunity_id, agency, opportunity_number, created_at, updated_at." 
- ], - "sort_direction": ["Must be one of: ascending, descending."], - } - }, + [{'field': 'sorting.order_by', 'message': 'Value must be one of: opportunity_id, agency, opportunity_number, created_at, updated_at', 'type': 'invalid_choice', 'value': None}, {'field': 'sorting.sort_direction', 'message': 'Must be one of: ascending, descending.', 'type': 'invalid_choice', 'value': None}], ), - (get_search_request(opportunity_title={}), {"opportunity_title": ["Not a valid string."]}), - (get_search_request(category="X"), {"category": ["Must be one of: D, M, C, E, O."]}), - (get_search_request(is_draft="hello"), {"is_draft": ["Not a valid boolean."]}), + (get_search_request(opportunity_title={}), [{'field': 'opportunity_title', 'message': 'Not a valid string.', 'type': 'invalid', 'value': None}]), + (get_search_request(category="X"), [{'field': 'category', 'message': 'Must be one of: D, M, C, E, O.', 'type': 'invalid_choice', 'value': None}]), + (get_search_request(is_draft="hello"), [{'field': 'is_draft', 'message': 'Not a valid boolean.', 'type': 'invalid', 'value': None}]), ], ) def test_opportunity_search_invalid_request_422( @@ -298,7 +284,8 @@ def test_opportunity_search_invalid_request_422( ) assert resp.status_code == 422 - response_data = resp.get_json()["detail"]["json"] + print(resp.get_json()) + response_data = resp.get_json()["errors"] assert response_data == expected_response_data diff --git a/api/tests/src/route/test_route_error_format.py b/api/tests/src/route/test_route_error_format.py new file mode 100644 index 000000000..859a09b0d --- /dev/null +++ b/api/tests/src/route/test_route_error_format.py @@ -0,0 +1,283 @@ +""" +There are several ways errors can be thrown by the API + +These tests aim to verify that the format and structure of the error +responses is consistent and functioning as intended. 
+""" + +import dataclasses + +import pytest +from apiflask import APIBlueprint +from werkzeug.exceptions import BadRequest, Forbidden, NotFound, Unauthorized +from werkzeug.http import HTTP_STATUS_CODES + +import src.app as app_entry +import src.logging +from src.auth.api_key_auth import api_key_auth +from src.api.response import ValidationErrorDetail +from src.api.utils.route_utils import raise_flask_error +from api.route.schemas.extension import Schema, fields +from api.util.dict_utils import flatten_dict +from tests.api.route.schemas.schema_validation_utils import ( + FieldTestSchema, + get_expected_validation_errors, + get_invalid_field_test_schema_req, + get_valid_field_test_schema_req, +) + +PATH = "/test/" +VALID_UUID = "1234a5b6-7c8d-90ef-1ab2-c3d45678e9f0" +FULL_PATH = PATH + VALID_UUID + + +def header(valid_jwt_token): + return {"X-NJ-UI-User": "nava-test", "X-NJ-UI-Key": valid_jwt_token} + + +class OutputSchema(Schema): + output_val = fields.String() + + +test_blueprint = APIBlueprint("test", __name__, tag="test") + + +class OverridenClass: + """ + In order to arbitrarily change the implementation of + the test endpoint, create a simple function that tests + below can override by doing:: + + def override(self): + # if this method returns, it returns + # the response as a dictionary + a list + # of validation issues to attach to the response + return {"output_val": "hello"}, [] + + monkeypatch.setattr(OverridenClass, "override_method", override) + """ + + def override_method(self): + return {"output_val": "hello"}, [] + + +@test_blueprint.patch("/test/") +@test_blueprint.input(FieldTestSchema) +@test_blueprint.output(OutputSchema) +@test_blueprint.auth_required(api_token_auth) +def api_method(id, req): + resp, warnings = OverridenClass().override_method() + return response.ApiResponse("Test method run successfully", data=resp, warnings=warnings) + + +@pytest.fixture +def simple_app(monkeypatch): + def stub(app): + pass + + # We want all the 
configurational setup for the app, but + # don't want the DB clients or blueprints to keep setup simpler + monkeypatch.setattr(app_entry, "register_db_clients", stub) + monkeypatch.setattr(app_entry, "register_blueprints", stub) + monkeypatch.setattr(app_entry, "setup_logging", stub) + + app = app_entry.create_app() + + # To avoid re-initializing logging everytime we + # setup the app, we disabled it above and do it here + # in case you want it while running your tests + with api.logging.init(__package__): + yield app + + +@pytest.fixture +def simple_client(simple_app): + simple_app.register_blueprint(test_blueprint) + return simple_app.test_client() + + +@pytest.mark.parametrize( + "exception", [Exception, AttributeError, IndexError, NotImplementedError, ValueError] +) +def test_exception(simple_client, valid_jwt_token, monkeypatch, exception): + def override(self): + raise exception("Exception message text") + + monkeypatch.setattr(OverridenClass, "override_method", override) + + resp = simple_client.patch( + FULL_PATH, json=get_valid_field_test_schema_req(), headers=header(valid_jwt_token) + ) + + assert resp.status_code == 500 + resp_json = resp.get_json() + assert resp_json["data"] == {} + assert resp_json["errors"] == [] + assert resp_json["message"] == "Internal Server Error" + + +@pytest.mark.parametrize("exception", [Unauthorized, NotFound, Forbidden, BadRequest]) +def test_werkzeug_exceptions(simple_client, valid_jwt_token, monkeypatch, exception): + def override(self): + raise exception("Exception message text") + + monkeypatch.setattr(OverridenClass, "override_method", override) + + resp = simple_client.patch( + FULL_PATH, json=get_valid_field_test_schema_req(), headers=header(valid_jwt_token) + ) + + # Werkzeug errors use the proper status code, but + # any message is replaced with a generic one they have defined + assert resp.status_code == exception.code + resp_json = resp.get_json() + assert resp_json["data"] == {} + assert resp_json["errors"] == [] 
+ assert resp_json["message"] == HTTP_STATUS_CODES[exception.code] + + +@pytest.mark.parametrize( + "error_code,message,detail,validation_issues", + [ + (422, "message", {"field": "value"}, []), + ( + 422, + "message but different", + None, + [ + ValidationErrorDetail( + type="example", message="example message", field="example_field", value="value" + ), + ValidationErrorDetail( + type="example2", message="example message2", field="example_field2", value=4 + ), + ], + ), + (401, "not allowed", {"field": "value"}, []), + (403, "bad request message", None, []), + ], +) +def test_flask_error( + simple_client, valid_jwt_token, monkeypatch, error_code, message, detail, validation_issues +): + def override(self): + raise_flask_error(error_code, message, detail=detail, validation_issues=validation_issues) + + monkeypatch.setattr(OverridenClass, "override_method", override) + + resp = simple_client.patch( + FULL_PATH, json=get_valid_field_test_schema_req(), headers=header(valid_jwt_token) + ) + + assert resp.status_code == error_code + resp_json = resp.get_json() + assert resp_json["message"] == message + + if detail is None: + assert resp_json["data"] == {} + else: + assert resp_json["data"] == detail + + if validation_issues: + errors = resp_json["errors"] + assert len(validation_issues) == len(errors) + + for validation_issue in validation_issues: + assert dataclasses.asdict(validation_issue) in errors + else: + assert resp_json["errors"] == [] + + +def test_invalid_path_param(simple_client, valid_jwt_token, monkeypatch): + resp = simple_client.patch( + PATH + "not-a-uuid", json=get_valid_field_test_schema_req(), headers=header(valid_jwt_token) + ) + + # This raises a Werkzeug NotFound so has those values + assert resp.status_code == 404 + resp_json = resp.get_json() + assert resp_json["data"] == {} + assert resp_json["errors"] == [] + assert resp_json["message"] == "Not Found" + + +def test_auth_error(simple_client, monkeypatch): + resp = simple_client.patch( + 
FULL_PATH, json=get_valid_field_test_schema_req(), headers=header("not_valid_jwt") + ) + + assert resp.status_code == 401 + resp_json = resp.get_json() + assert resp_json["data"] == {} + assert resp_json["errors"] == [] + assert resp_json["message"] == "There was an error verifying token" + + +@pytest.mark.parametrize( + "issues", + [ + [], + [ + ValidationErrorDetail( + type="required", message="Field is required", field="sub_obj.field_a" + ), + ValidationErrorDetail( + type="format", message="Invalid format for type string", field="field_b" + ), + ], + [ValidationErrorDetail(type="bad", message="field is optional technically")], + ], +) +def test_added_validation_issues(simple_client, valid_jwt_token, monkeypatch, issues): + def override(self): + return {"output_val": "hello with validation issues"}, issues + + monkeypatch.setattr(OverridenClass, "override_method", override) + + resp = simple_client.patch( + FULL_PATH, json=get_valid_field_test_schema_req(), headers=header(valid_jwt_token) + ) + + assert resp.status_code == 200 + resp_json = resp.get_json() + assert resp_json["data"] == {"output_val": "hello with validation issues"} + assert resp_json["message"] == "Test method run successfully" + + warnings = resp_json["warnings"] + + assert len(issues) == len(warnings) + for issue in issues: + assert dataclasses.asdict(issue) in warnings + + +def test_marshmallow_validation(simple_client, valid_jwt_token, monkeypatch): + """ + Validate that Marshmallow errors get transformed properly + and attached in the expected format in an error response + """ + + req = get_invalid_field_test_schema_req() + resp = simple_client.patch(FULL_PATH, json=req, headers=header(valid_jwt_token)) + + assert resp.status_code == 422 + resp_json = resp.get_json() + assert resp_json["data"] == {} + assert resp_json["message"] == "Validation error" + + resp_errors = resp_json["errors"] + + expected_errors = [] + for field, errors in flatten_dict(get_expected_validation_errors()).items(): 
+ for error in errors: + expected_errors.append( + { + "type": error.key, + "message": error.message, + "field": field.removesuffix("._schema"), + "value": None, + } + ) + + assert len(expected_errors) == len(resp_errors) + for expected_error in expected_errors: + assert expected_error in resp_errors diff --git a/api/tests/src/util/test_dict_util.py b/api/tests/src/util/test_dict_util.py new file mode 100644 index 000000000..9349dac11 --- /dev/null +++ b/api/tests/src/util/test_dict_util.py @@ -0,0 +1,53 @@ +import pytest + +from src.util.dict_util import flatten_dict + + +@pytest.mark.parametrize( + "data,expected_output", + [ + # Scenario 1 - routine case + ( + {"a": {"b": {"c": "value_c", "f": 5}, "d": "value_d"}, "e": "value_e"}, + {"a.b.c": "value_c", "a.b.f": 5, "a.d": "value_d", "e": "value_e"}, + ), + # Scenario 2 - empty + ({}, {}), + # Scenario 3 - no nesting + ( + { + "a": "1", + "b": 2, + "c": True, + }, + { + "a": "1", + "b": 2, + "c": True, + }, + ), + # Scenario 4 - very nested + ( + { + "a": { + "b": { + "c": { + "d": { + "e": { + "f": {"g": {"h1": "h1_value", "h2": ["h2_value1", "h2_value2"]}} + } + } + } + } + } + }, + {"a.b.c.d.e.f.g.h1": "h1_value", "a.b.c.d.e.f.g.h2": ["h2_value1", "h2_value2"]}, + ), + # Scenario 5 - dictionaries inside non-dictionaries aren't flattened + ({"a": {"b": [{"list_dict_a": "a"}]}}, {"a.b": [{"list_dict_a": "a"}]}), + # Scenario 6 - integer keys should be allowed too + ({"a": {0: {"b": "b_value"}, 1: "c"}}, {"a.0.b": "b_value", "a.1": "c"}), + ], +) +def test_flatten_dict(data, expected_output): + assert flatten_dict(data) == expected_output From e53a180c463caade4c754b34ab9cbc17b4e1fbef Mon Sep 17 00:00:00 2001 From: Michael Chouinard Date: Thu, 9 Nov 2023 14:22:59 -0500 Subject: [PATCH 3/6] Last bits of cleanup --- api/src/api/healthcheck.py | 2 +- .../api/opportunities/opportunity_schemas.py | 4 +- api/src/api/response.py | 11 ++- api/src/api/route_utils.py | 8 +- .../api/schemas/extension/schema_fields.py | 4 +- 
api/src/api/schemas/response_schema.py | 15 +-- api/src/app.py | 20 ++-- api/src/pagination/pagination_schema.py | 6 +- api/tests/src/route/test_opportunity_route.py | 93 +++++++++++++++++-- .../src/route/test_route_error_format.py | 55 +++++------ 10 files changed, 155 insertions(+), 63 deletions(-) diff --git a/api/src/api/healthcheck.py b/api/src/api/healthcheck.py index 767df63d5..22f944273 100644 --- a/api/src/api/healthcheck.py +++ b/api/src/api/healthcheck.py @@ -8,7 +8,7 @@ import src.adapters.db.flask_db as flask_db from src.api import response -from src.api.schemas.extension import fields, Schema +from src.api.schemas.extension import Schema, fields logger = logging.getLogger(__name__) diff --git a/api/src/api/opportunities/opportunity_schemas.py b/api/src/api/opportunities/opportunity_schemas.py index e09ce683f..fff3203c4 100644 --- a/api/src/api/opportunities/opportunity_schemas.py +++ b/api/src/api/opportunities/opportunity_schemas.py @@ -1,10 +1,10 @@ from typing import Any -from src.api.schemas.extension import fields, Schema from marshmallow import post_load from src.api.feature_flags.feature_flag import FeatureFlag from src.api.feature_flags.feature_flag_config import FeatureFlagConfig, get_feature_flag_config +from src.api.schemas.extension import Schema, fields from src.constants.lookup_constants import OpportunityCategory from src.pagination.pagination_schema import PaginationSchema, generate_sorting_schema @@ -30,7 +30,6 @@ class OpportunitySchema(Schema): category = fields.Enum( OpportunityCategory, - by_value=True, metadata={ "description": "The opportunity category", "example": OpportunityCategory.DISCRETIONARY, @@ -54,7 +53,6 @@ class OpportunitySearchSchema(Schema): ) category = fields.Enum( OpportunityCategory, - by_value=True, metadata={ "description": "The opportunity category to search for", "example": OpportunityCategory.DISCRETIONARY, diff --git a/api/src/api/response.py b/api/src/api/response.py index 1283aaccc..5008e8e6c 100644 --- 
a/api/src/api/response.py +++ b/api/src/api/response.py @@ -1,14 +1,16 @@ import dataclasses -import apiflask import logging from typing import Any, Optional, Tuple +import apiflask + from src.api.schemas.extension import MarshmallowErrorContainer from src.pagination.pagination_models import PaginationInfo from src.util.dict_util import flatten_dict logger = logging.getLogger(__name__) + @dataclasses.dataclass class ValidationErrorDetail: type: str @@ -81,6 +83,7 @@ def process_marshmallow_issues(marshmallow_issues: dict) -> list[ValidationError return validation_errors + def restructure_error_response(error: apiflask.exceptions.HTTPError) -> Tuple[dict, int, Any]: # Note that body needs to have the same schema as the ErrorResponseSchema we defined # in app.api.route.schemas.response_schema.py @@ -103,6 +106,12 @@ def restructure_error_response(error: apiflask.exceptions.HTTPError) -> Tuple[di # so we remove the now-duplicate error detail del body["data"]["json"] + marshmallow_issues = error.detail.get("headers") + if marshmallow_issues: + validation_errors.extend(process_marshmallow_issues(marshmallow_issues)) + + del body["data"]["headers"] + # If we called raise_flask_error with a list of validation_issues # then they get appended to the error response here additional_validation_issues = error.extra_data.get("validation_issues") diff --git a/api/src/api/route_utils.py b/api/src/api/route_utils.py index 1dccebcd3..704f628da 100644 --- a/api/src/api/route_utils.py +++ b/api/src/api/route_utils.py @@ -3,16 +3,20 @@ from apiflask import abort from apiflask.types import ResponseHeaderType +from src.api.response import ValidationErrorDetail + def raise_flask_error( # type: ignore status_code: int, message: str | None = None, detail: Any = None, headers: ResponseHeaderType | None = None, - # TODO - when we work on validation error responses, we'll want to take in those here + validation_issues: list[ValidationErrorDetail] | None = None, ) -> Never: # Wrapper around the 
abort method which makes an error during API processing # work properly when APIFlask generates a response. # mypy doesn't realize this method never returns, so we define the same method # with a return type of Never. - abort(status_code, message, detail, headers) + abort( + status_code, message, detail, headers, extra_data={"validation_issues": validation_issues} + ) diff --git a/api/src/api/schemas/extension/schema_fields.py b/api/src/api/schemas/extension/schema_fields.py index aa3e2491f..7409ecb40 100644 --- a/api/src/api/schemas/extension/schema_fields.py +++ b/api/src/api/schemas/extension/schema_fields.py @@ -38,8 +38,8 @@ class MixinField(original_fields.Field): ), } - def __init__(self, allow_none: bool = True, **kwargs: typing.Any) -> None: - super().__init__(allow_none=allow_none, **kwargs) + def __init__(self, **kwargs: typing.Any) -> None: + super().__init__(**kwargs) # The actual error mapping used for a specific instance self._error_mapping: dict[str, MarshmallowErrorContainer] = {} diff --git a/api/src/api/schemas/response_schema.py b/api/src/api/schemas/response_schema.py index e11b79ca4..9360c0c0e 100644 --- a/api/src/api/schemas/response_schema.py +++ b/api/src/api/schemas/response_schema.py @@ -1,6 +1,5 @@ - +from src.api.schemas.extension import Schema, fields from src.pagination.pagination_schema import PaginationInfoSchema -from src.api.schemas.extension import fields, Schema class ValidationIssueSchema(Schema): @@ -15,13 +14,15 @@ class BaseResponseSchema(Schema): data = fields.MixinField(metadata={"description": "The REST resource object"}, dump_default={}) status_code = fields.Integer(metadata={"description": "The HTTP status code"}, dump_default=200) - pagination_info = fields.Nested( - PaginationInfoSchema(), - metadata={"description": "The pagination information for paginated endpoints"}, - ) class ErrorResponseSchema(BaseResponseSchema): errors = fields.List(fields.Nested(ValidationIssueSchema()), dump_default=[]) + class 
ResponseSchema(BaseResponseSchema): - warnings = fields.List(fields.Nested(ValidationIssueSchema()), dump_default=[]) \ No newline at end of file + pagination_info = fields.Nested( + PaginationInfoSchema(), + metadata={"description": "The pagination information for paginated endpoints"}, + ) + + warnings = fields.List(fields.Nested(ValidationIssueSchema()), dump_default=[]) diff --git a/api/src/app.py b/api/src/app.py index 15a1ac8e5..19425d6a0 100644 --- a/api/src/app.py +++ b/api/src/app.py @@ -1,6 +1,6 @@ import logging import os -from typing import Optional, Tuple, Any +from typing import Any, Optional, Tuple from apiflask import APIFlask, exceptions from flask import g @@ -23,11 +23,8 @@ def create_app() -> APIFlask: app = APIFlask(__name__) - src.logging.init(__package__) - flask_logger.init_app(logging.root, app) - - db_client = db.PostgresDBClient() - flask_db.register_db_client(db_client, app) + setup_logging(app) + register_db_client(app) feature_flag_config.initialize() @@ -46,6 +43,16 @@ def current_user(is_user_expected: bool = True) -> Optional[User]: return current +def setup_logging(app: APIFlask) -> None: + src.logging.init(__package__) + flask_logger.init_app(logging.root, app) + + +def register_db_client(app: APIFlask) -> None: + db_client = db.PostgresDBClient() + flask_db.register_db_client(db_client, app) + + def configure_app(app: APIFlask) -> None: # Modify the response schema to instead use the format of our ApiResponse class # which adds additional details to the object. 
@@ -75,7 +82,6 @@ def configure_app(app: APIFlask) -> None: # See: https://apiflask.com/authentication/#use-external-authentication-library app.security_schemes = get_app_security_scheme() - @app.error_processor def error_processor(error: exceptions.HTTPError) -> Tuple[dict, int, Any]: return restructure_error_response(error) diff --git a/api/src/pagination/pagination_schema.py b/api/src/pagination/pagination_schema.py index f33d6ce2c..bf2cf9a31 100644 --- a/api/src/pagination/pagination_schema.py +++ b/api/src/pagination/pagination_schema.py @@ -1,7 +1,6 @@ from typing import Type -from src.api.schemas.extension import fields, validators, Schema - +from src.api.schemas.extension import Schema, fields, validators from src.pagination.pagination_models import SortDirection @@ -38,7 +37,6 @@ class MySortingSchema(Schema): sort_direction = fields.Enum( SortDirection, required=True, - by_value=True, metadata={"description": "Whether to sort the response ascending or descending"}, ) """ @@ -54,7 +52,6 @@ class MySortingSchema(Schema): "sort_direction": fields.Enum( SortDirection, required=True, - by_value=True, metadata={"description": "Whether to sort the response ascending or descending"}, ), } @@ -81,6 +78,5 @@ class PaginationInfoSchema(Schema): ) sort_direction = fields.Enum( SortDirection, - by_value=True, metadata={"description": "The direction the records are sorted"}, ) diff --git a/api/tests/src/route/test_opportunity_route.py b/api/tests/src/route/test_opportunity_route.py index c06b0a74e..1da81f0c8 100644 --- a/api/tests/src/route/test_opportunity_route.py +++ b/api/tests/src/route/test_opportunity_route.py @@ -265,20 +265,88 @@ def test_opportunity_search_paging_and_sorting_200( [ ( {}, - [{'field': 'sorting', 'message': 'Missing data for required field.', 'type': 'required', 'value': None}, - {'field': 'paging', 'message': 'Missing data for required field.', 'type': 'required', 'value': None}] + [ + { + "field": "sorting", + "message": "Missing data for 
required field.", + "type": "required", + "value": None, + }, + { + "field": "paging", + "message": "Missing data for required field.", + "type": "required", + "value": None, + }, + ], ), ( get_search_request(page_offset=-1, page_size=-1), - [{'field': 'paging.page_size', 'message': 'Must be greater than or equal to 1.', 'type': 'min_or_max_value', 'value': None}, {'field': 'paging.page_offset', 'message': 'Must be greater than or equal to 1.', 'type': 'min_or_max_value', 'value': None}], + [ + { + "field": "paging.page_size", + "message": "Must be greater than or equal to 1.", + "type": "min_or_max_value", + "value": None, + }, + { + "field": "paging.page_offset", + "message": "Must be greater than or equal to 1.", + "type": "min_or_max_value", + "value": None, + }, + ], ), ( get_search_request(order_by="fake_field", sort_direction="up"), - [{'field': 'sorting.order_by', 'message': 'Value must be one of: opportunity_id, agency, opportunity_number, created_at, updated_at', 'type': 'invalid_choice', 'value': None}, {'field': 'sorting.sort_direction', 'message': 'Must be one of: ascending, descending.', 'type': 'invalid_choice', 'value': None}], + [ + { + "field": "sorting.order_by", + "message": "Value must be one of: opportunity_id, agency, opportunity_number, created_at, updated_at", + "type": "invalid_choice", + "value": None, + }, + { + "field": "sorting.sort_direction", + "message": "Must be one of: ascending, descending.", + "type": "invalid_choice", + "value": None, + }, + ], + ), + ( + get_search_request(opportunity_title={}), + [ + { + "field": "opportunity_title", + "message": "Not a valid string.", + "type": "invalid", + "value": None, + } + ], + ), + ( + get_search_request(category="X"), + [ + { + "field": "category", + "message": "Must be one of: D, M, C, E, O.", + "type": "invalid_choice", + "value": None, + } + ], + ), + ( + get_search_request(is_draft="hello"), + [ + { + "field": "is_draft", + "message": "Not a valid boolean.", + "type": "invalid", + 
"value": None, + } + ], ), - (get_search_request(opportunity_title={}), [{'field': 'opportunity_title', 'message': 'Not a valid string.', 'type': 'invalid', 'value': None}]), - (get_search_request(category="X"), [{'field': 'category', 'message': 'Must be one of: D, M, C, E, O.', 'type': 'invalid_choice', 'value': None}]), - (get_search_request(is_draft="hello"), [{'field': 'is_draft', 'message': 'Not a valid boolean.', 'type': 'invalid', 'value': None}]), ], ) def test_opportunity_search_invalid_request_422( @@ -324,8 +392,15 @@ def test_opportunity_search_feature_flag_invalid_value_422( resp = client.post("/v1/opportunities/search", json=get_search_request(), headers=headers) assert resp.status_code == 422 - response_data = resp.get_json()["detail"]["headers"] - assert response_data == {"X-FF-Enable-Opportunity-Log-Msg": ["Not a valid boolean."]} + response_data = resp.get_json()["errors"] + assert response_data == [ + { + "field": "X-FF-Enable-Opportunity-Log-Msg", + "message": "Not a valid boolean.", + "type": "invalid", + "value": None, + } + ] ##################################### diff --git a/api/tests/src/route/test_route_error_format.py b/api/tests/src/route/test_route_error_format.py index 859a09b0d..d5683ef19 100644 --- a/api/tests/src/route/test_route_error_format.py +++ b/api/tests/src/route/test_route_error_format.py @@ -14,12 +14,12 @@ import src.app as app_entry import src.logging +from src.api.response import ApiResponse, ValidationErrorDetail +from src.api.route_utils import raise_flask_error +from src.api.schemas.extension import Schema, fields from src.auth.api_key_auth import api_key_auth -from src.api.response import ValidationErrorDetail -from src.api.utils.route_utils import raise_flask_error -from api.route.schemas.extension import Schema, fields -from api.util.dict_utils import flatten_dict -from tests.api.route.schemas.schema_validation_utils import ( +from src.util.dict_util import flatten_dict +from 
tests.api.schemas.schema_validation_utils import ( FieldTestSchema, get_expected_validation_errors, get_invalid_field_test_schema_req, @@ -31,8 +31,8 @@ FULL_PATH = PATH + VALID_UUID -def header(valid_jwt_token): - return {"X-NJ-UI-User": "nava-test", "X-NJ-UI-Key": valid_jwt_token} +def header(api_auth_token): + return {"X-Auth": api_auth_token} class OutputSchema(Schema): @@ -62,12 +62,12 @@ def override_method(self): @test_blueprint.patch("/test/") -@test_blueprint.input(FieldTestSchema) +@test_blueprint.input(FieldTestSchema, arg_name="req") @test_blueprint.output(OutputSchema) -@test_blueprint.auth_required(api_token_auth) -def api_method(id, req): +@test_blueprint.auth_required(api_key_auth) +def api_method(test_id, req): resp, warnings = OverridenClass().override_method() - return response.ApiResponse("Test method run successfully", data=resp, warnings=warnings) + return ApiResponse("Test method run successfully", data=resp, warnings=warnings) @pytest.fixture @@ -77,7 +77,7 @@ def stub(app): # We want all the configurational setup for the app, but # don't want the DB clients or blueprints to keep setup simpler - monkeypatch.setattr(app_entry, "register_db_clients", stub) + monkeypatch.setattr(app_entry, "register_db_client", stub) monkeypatch.setattr(app_entry, "register_blueprints", stub) monkeypatch.setattr(app_entry, "setup_logging", stub) @@ -86,7 +86,7 @@ def stub(app): # To avoid re-initializing logging everytime we # setup the app, we disabled it above and do it here # in case you want it while running your tests - with api.logging.init(__package__): + with src.logging.init(__package__): yield app @@ -99,14 +99,14 @@ def simple_client(simple_app): @pytest.mark.parametrize( "exception", [Exception, AttributeError, IndexError, NotImplementedError, ValueError] ) -def test_exception(simple_client, valid_jwt_token, monkeypatch, exception): +def test_exception(simple_client, api_auth_token, monkeypatch, exception): def override(self): raise 
exception("Exception message text") monkeypatch.setattr(OverridenClass, "override_method", override) resp = simple_client.patch( - FULL_PATH, json=get_valid_field_test_schema_req(), headers=header(valid_jwt_token) + FULL_PATH, json=get_valid_field_test_schema_req(), headers=header(api_auth_token) ) assert resp.status_code == 500 @@ -117,14 +117,14 @@ def override(self): @pytest.mark.parametrize("exception", [Unauthorized, NotFound, Forbidden, BadRequest]) -def test_werkzeug_exceptions(simple_client, valid_jwt_token, monkeypatch, exception): +def test_werkzeug_exceptions(simple_client, api_auth_token, monkeypatch, exception): def override(self): raise exception("Exception message text") monkeypatch.setattr(OverridenClass, "override_method", override) resp = simple_client.patch( - FULL_PATH, json=get_valid_field_test_schema_req(), headers=header(valid_jwt_token) + FULL_PATH, json=get_valid_field_test_schema_req(), headers=header(api_auth_token) ) # Werkzeug errors use the proper status code, but @@ -158,7 +158,7 @@ def override(self): ], ) def test_flask_error( - simple_client, valid_jwt_token, monkeypatch, error_code, message, detail, validation_issues + simple_client, api_auth_token, monkeypatch, error_code, message, detail, validation_issues ): def override(self): raise_flask_error(error_code, message, detail=detail, validation_issues=validation_issues) @@ -166,7 +166,7 @@ def override(self): monkeypatch.setattr(OverridenClass, "override_method", override) resp = simple_client.patch( - FULL_PATH, json=get_valid_field_test_schema_req(), headers=header(valid_jwt_token) + FULL_PATH, json=get_valid_field_test_schema_req(), headers=header(api_auth_token) ) assert resp.status_code == error_code @@ -188,9 +188,9 @@ def override(self): assert resp_json["errors"] == [] -def test_invalid_path_param(simple_client, valid_jwt_token, monkeypatch): +def test_invalid_path_param(simple_client, api_auth_token, monkeypatch): resp = simple_client.patch( - PATH + "not-a-uuid", 
json=get_valid_field_test_schema_req(), headers=header(valid_jwt_token) + PATH + "not-a-uuid", json=get_valid_field_test_schema_req(), headers=header(api_auth_token) ) # This raises a Werkzeug NotFound so has those values @@ -210,7 +210,10 @@ def test_auth_error(simple_client, monkeypatch): resp_json = resp.get_json() assert resp_json["data"] == {} assert resp_json["errors"] == [] - assert resp_json["message"] == "There was an error verifying token" + assert ( + resp_json["message"] + == "The server could not verify that you are authorized to access the URL requested" + ) @pytest.mark.parametrize( @@ -228,14 +231,14 @@ def test_auth_error(simple_client, monkeypatch): [ValidationErrorDetail(type="bad", message="field is optional technically")], ], ) -def test_added_validation_issues(simple_client, valid_jwt_token, monkeypatch, issues): +def test_added_validation_issues(simple_client, api_auth_token, monkeypatch, issues): def override(self): return {"output_val": "hello with validation issues"}, issues monkeypatch.setattr(OverridenClass, "override_method", override) resp = simple_client.patch( - FULL_PATH, json=get_valid_field_test_schema_req(), headers=header(valid_jwt_token) + FULL_PATH, json=get_valid_field_test_schema_req(), headers=header(api_auth_token) ) assert resp.status_code == 200 @@ -250,14 +253,14 @@ def override(self): assert dataclasses.asdict(issue) in warnings -def test_marshmallow_validation(simple_client, valid_jwt_token, monkeypatch): +def test_marshmallow_validation(simple_client, api_auth_token, monkeypatch): """ Validate that Marshmallow errors get transformed properly and attached in the expected format in an error response """ req = get_invalid_field_test_schema_req() - resp = simple_client.patch(FULL_PATH, json=req, headers=header(valid_jwt_token)) + resp = simple_client.patch(FULL_PATH, json=req, headers=header(api_auth_token)) assert resp.status_code == 422 resp_json = resp.get_json() From 87b636c3fa16c8515032bbe052429a6fbd0a381b Mon Sep 
17 00:00:00 2001 From: Michael Chouinard Date: Thu, 9 Nov 2023 15:34:36 -0500 Subject: [PATCH 4/6] Adding the openapi update file --- api/openapi.generated.yml | 153 ++++++++++++++++++++++++-------------- 1 file changed, 99 insertions(+), 54 deletions(-) diff --git a/api/openapi.generated.yml b/api/openapi.generated.yml index 3e6f1c652..bd6b54978 100644 --- a/api/openapi.generated.yml +++ b/api/openapi.generated.yml @@ -27,25 +27,28 @@ paths: $ref: '#/components/schemas/Healthcheck' status_code: type: integer + minimum: -2147483648 + maximum: 2147483647 description: The HTTP status code - warnings: - type: array - items: - $ref: '#/components/schemas/ValidationError' - errors: - type: array - items: - $ref: '#/components/schemas/ValidationError' pagination_info: description: The pagination information for paginated endpoints + type: + - object allOf: - $ref: '#/components/schemas/PaginationInfo' + warnings: + type: array + items: + type: + - object + allOf: + - $ref: '#/components/schemas/ValidationIssue' description: Successful response '503': content: application/json: schema: - $ref: '#/components/schemas/HTTPError' + $ref: '#/components/schemas/ErrorResponse' description: Service Unavailable tags: - Health @@ -75,31 +78,34 @@ paths: $ref: '#/components/schemas/Opportunity' status_code: type: integer + minimum: -2147483648 + maximum: 2147483647 description: The HTTP status code - warnings: - type: array - items: - $ref: '#/components/schemas/ValidationError' - errors: - type: array - items: - $ref: '#/components/schemas/ValidationError' pagination_info: description: The pagination information for paginated endpoints + type: + - object allOf: - $ref: '#/components/schemas/PaginationInfo' + warnings: + type: array + items: + type: + - object + allOf: + - $ref: '#/components/schemas/ValidationIssue' description: Successful response '422': content: application/json: schema: - $ref: '#/components/schemas/ValidationError' + $ref: '#/components/schemas/ErrorResponse' 
description: Validation error '401': content: application/json: schema: - $ref: '#/components/schemas/HTTPError' + $ref: '#/components/schemas/ErrorResponse' description: Authentication error tags: - Opportunity @@ -133,31 +139,34 @@ paths: $ref: '#/components/schemas/Opportunity' status_code: type: integer + minimum: -2147483648 + maximum: 2147483647 description: The HTTP status code - warnings: - type: array - items: - $ref: '#/components/schemas/ValidationError' - errors: - type: array - items: - $ref: '#/components/schemas/ValidationError' pagination_info: description: The pagination information for paginated endpoints + type: + - object allOf: - $ref: '#/components/schemas/PaginationInfo' + warnings: + type: array + items: + type: + - object + allOf: + - $ref: '#/components/schemas/ValidationIssue' description: Successful response '401': content: application/json: schema: - $ref: '#/components/schemas/HTTPError' + $ref: '#/components/schemas/ErrorResponse' description: Authentication error '404': content: application/json: schema: - $ref: '#/components/schemas/HTTPError' + $ref: '#/components/schemas/ErrorResponse' description: Not found tags: - Opportunity @@ -167,41 +176,31 @@ paths: openapi: 3.1.0 components: schemas: - ValidationError: - type: object - properties: - type: - type: string - description: The type of error - message: - type: string - description: The message to return - rule: - type: string - description: The rule that failed - field: - type: string - description: The field that failed - value: - type: string - description: The value that failed PaginationInfo: type: object properties: page_offset: type: integer + minimum: -2147483648 + maximum: 2147483647 description: The page number that was fetched example: 1 page_size: type: integer + minimum: -2147483648 + maximum: 2147483647 description: The size of the page fetched example: 25 total_records: type: integer + minimum: -2147483648 + maximum: 2147483647 description: The total number of 
records fetchable example: 42 total_pages: type: integer + minimum: -2147483648 + maximum: 2147483647 description: The total number of pages that can be fetched example: 2 order_by: @@ -213,16 +212,48 @@ components: enum: - ascending - descending - HTTPError: + type: + - string + ValidationIssue: + type: object properties: - detail: - type: object + type: + type: string + description: The type of error message: type: string - type: object + description: The message to return + field: + type: string + description: The field that failed + value: + type: string + description: The value that failed Healthcheck: type: object - properties: {} + properties: + message: + type: string + ErrorResponse: + type: object + properties: + message: + type: string + description: The message to return + data: + description: The REST resource object + status_code: + type: integer + minimum: -2147483648 + maximum: 2147483647 + description: The HTTP status code + errors: + type: array + items: + type: + - object + allOf: + - $ref: '#/components/schemas/ValidationIssue' OpportunitySorting: type: object properties: @@ -240,6 +271,8 @@ components: enum: - ascending - descending + type: + - string required: - order_by - sort_direction @@ -276,14 +309,22 @@ components: - C - E - O + type: + - string is_draft: type: boolean description: Whether to search for draft claims example: false sorting: - $ref: '#/components/schemas/OpportunitySorting' + type: + - object + allOf: + - $ref: '#/components/schemas/OpportunitySorting' paging: - $ref: '#/components/schemas/Pagination' + type: + - object + allOf: + - $ref: '#/components/schemas/Pagination' required: - paging - sorting @@ -293,6 +334,8 @@ components: opportunity_id: type: integer readOnly: true + minimum: -2147483648 + maximum: 2147483647 description: The internal ID of the opportunity example: 12345 opportunity_number: @@ -317,6 +360,8 @@ components: - C - E - O + type: + - string is_draft: type: boolean description: Whether the opportunity 
is in a draft status From 64f8f51b1cd69e9413809b26d440b6be4c117ba1 Mon Sep 17 00:00:00 2001 From: Michael Chouinard Date: Tue, 14 Nov 2023 10:22:50 -0500 Subject: [PATCH 5/6] Removing int validation default and value in error response --- api/src/api/response.py | 1 - api/src/api/schemas/extension/schema_fields.py | 2 +- api/src/api/schemas/response_schema.py | 1 - api/tests/src/route/test_opportunity_route.py | 10 ---------- api/tests/src/route/test_route_error_format.py | 5 ++--- 5 files changed, 3 insertions(+), 16 deletions(-) diff --git a/api/src/api/response.py b/api/src/api/response.py index 5008e8e6c..0990dc120 100644 --- a/api/src/api/response.py +++ b/api/src/api/response.py @@ -16,7 +16,6 @@ class ValidationErrorDetail: type: str message: str = "" field: Optional[str] = None - value: Optional[str] = None # Do not store PII data here, as it gets logged in some cases class ValidationException(apiflask.exceptions.HTTPError): diff --git a/api/src/api/schemas/extension/schema_fields.py b/api/src/api/schemas/extension/schema_fields.py index 7409ecb40..97b08636d 100644 --- a/api/src/api/schemas/extension/schema_fields.py +++ b/api/src/api/schemas/extension/schema_fields.py @@ -88,7 +88,7 @@ class Integer(original_fields.Integer, MixinField): "invalid": MarshmallowErrorContainer(ValidationErrorType.INVALID, "Not a valid integer."), } - def __init__(self, restrict_to_32bit_int: bool = True, **kwargs: typing.Any): + def __init__(self, restrict_to_32bit_int: bool = False, **kwargs: typing.Any): # By default, we'll restrict all integer values to 32-bits so that they can be stored in # Postgres' integer column. If you wish to process a larger value, simply set this to false or specify # your own min/max Range. 
diff --git a/api/src/api/schemas/response_schema.py b/api/src/api/schemas/response_schema.py index 9360c0c0e..b6509699b 100644 --- a/api/src/api/schemas/response_schema.py +++ b/api/src/api/schemas/response_schema.py @@ -6,7 +6,6 @@ class ValidationIssueSchema(Schema): type = fields.String(metadata={"description": "The type of error"}) message = fields.String(metadata={"description": "The message to return"}) field = fields.String(metadata={"description": "The field that failed"}) - value = fields.String(metadata={"description": "The value that failed"}) class BaseResponseSchema(Schema): diff --git a/api/tests/src/route/test_opportunity_route.py b/api/tests/src/route/test_opportunity_route.py index 03f62af61..1ae091aa7 100644 --- a/api/tests/src/route/test_opportunity_route.py +++ b/api/tests/src/route/test_opportunity_route.py @@ -270,13 +270,11 @@ def test_opportunity_search_paging_and_sorting_200( "field": "sorting", "message": "Missing data for required field.", "type": "required", - "value": None, }, { "field": "paging", "message": "Missing data for required field.", "type": "required", - "value": None, }, ], ), @@ -287,13 +285,11 @@ def test_opportunity_search_paging_and_sorting_200( "field": "paging.page_size", "message": "Must be greater than or equal to 1.", "type": "min_or_max_value", - "value": None, }, { "field": "paging.page_offset", "message": "Must be greater than or equal to 1.", "type": "min_or_max_value", - "value": None, }, ], ), @@ -304,13 +300,11 @@ def test_opportunity_search_paging_and_sorting_200( "field": "sorting.order_by", "message": "Value must be one of: opportunity_id, agency, opportunity_number, created_at, updated_at", "type": "invalid_choice", - "value": None, }, { "field": "sorting.sort_direction", "message": "Must be one of: ascending, descending.", "type": "invalid_choice", - "value": None, }, ], ), @@ -321,7 +315,6 @@ def test_opportunity_search_paging_and_sorting_200( "field": "opportunity_title", "message": "Not a valid 
string.", "type": "invalid", - "value": None, } ], ), @@ -332,7 +325,6 @@ def test_opportunity_search_paging_and_sorting_200( "field": "category", "message": "Must be one of: D, M, C, E, O.", "type": "invalid_choice", - "value": None, } ], ), @@ -343,7 +335,6 @@ def test_opportunity_search_paging_and_sorting_200( "field": "is_draft", "message": "Not a valid boolean.", "type": "invalid", - "value": None, } ], ), @@ -398,7 +389,6 @@ def test_opportunity_search_feature_flag_invalid_value_422( "field": "FF-Enable-Opportunity-Log-Msg", "message": "Not a valid boolean.", "type": "invalid", - "value": None, } ] diff --git a/api/tests/src/route/test_route_error_format.py b/api/tests/src/route/test_route_error_format.py index d5683ef19..4eebfda19 100644 --- a/api/tests/src/route/test_route_error_format.py +++ b/api/tests/src/route/test_route_error_format.py @@ -146,10 +146,10 @@ def override(self): None, [ ValidationErrorDetail( - type="example", message="example message", field="example_field", value="value" + type="example", message="example message", field="example_field" ), ValidationErrorDetail( - type="example2", message="example message2", field="example_field2", value=4 + type="example2", message="example message2", field="example_field2" ), ], ), @@ -277,7 +277,6 @@ def test_marshmallow_validation(simple_client, api_auth_token, monkeypatch): "type": error.key, "message": error.message, "field": field.removesuffix("._schema"), - "value": None, } ) From 52eda6d4bfd8dcd3efe43df26a9fc8bfe9f999c6 Mon Sep 17 00:00:00 2001 From: nava-platform-bot Date: Tue, 14 Nov 2023 15:25:14 +0000 Subject: [PATCH 6/6] Update OpenAPI spec --- api/openapi.generated.yml | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/api/openapi.generated.yml b/api/openapi.generated.yml index bd6b54978..945353f1e 100644 --- a/api/openapi.generated.yml +++ b/api/openapi.generated.yml @@ -27,8 +27,6 @@ paths: $ref: '#/components/schemas/Healthcheck' status_code: type: integer - 
minimum: -2147483648 - maximum: 2147483647 description: The HTTP status code pagination_info: description: The pagination information for paginated endpoints @@ -78,8 +76,6 @@ paths: $ref: '#/components/schemas/Opportunity' status_code: type: integer - minimum: -2147483648 - maximum: 2147483647 description: The HTTP status code pagination_info: description: The pagination information for paginated endpoints @@ -139,8 +135,6 @@ paths: $ref: '#/components/schemas/Opportunity' status_code: type: integer - minimum: -2147483648 - maximum: 2147483647 description: The HTTP status code pagination_info: description: The pagination information for paginated endpoints @@ -181,26 +175,18 @@ components: properties: page_offset: type: integer - minimum: -2147483648 - maximum: 2147483647 description: The page number that was fetched example: 1 page_size: type: integer - minimum: -2147483648 - maximum: 2147483647 description: The size of the page fetched example: 25 total_records: type: integer - minimum: -2147483648 - maximum: 2147483647 description: The total number of records fetchable example: 42 total_pages: type: integer - minimum: -2147483648 - maximum: 2147483647 description: The total number of pages that can be fetched example: 2 order_by: @@ -226,9 +212,6 @@ components: field: type: string description: The field that failed - value: - type: string - description: The value that failed Healthcheck: type: object properties: @@ -244,8 +227,6 @@ components: description: The REST resource object status_code: type: integer - minimum: -2147483648 - maximum: 2147483647 description: The HTTP status code errors: type: array @@ -334,8 +315,6 @@ components: opportunity_id: type: integer readOnly: true - minimum: -2147483648 - maximum: 2147483647 description: The internal ID of the opportunity example: 12345 opportunity_number: