diff --git a/src/cattrs/converters.py b/src/cattrs/converters.py index 16bee1dd..441b8c2c 100644 --- a/src/cattrs/converters.py +++ b/src/cattrs/converters.py @@ -65,7 +65,7 @@ IterableValidationNote, StructureHandlerNotFoundError, ) -from .fns import identity, raise_error +from .fns import Predicate, identity, raise_error from .gen import ( AttributeOverride, DictStructureFn, @@ -174,6 +174,7 @@ def __init__( self._prefer_attrib_converters = prefer_attrib_converters self.detailed_validation = detailed_validation + self._union_struct_registry: dict[Any, Callable[[Any, type[T]], T]] = {} # Create a per-instance cache. if unstruct_strat is UnstructureStrategy.AS_DICT: @@ -246,7 +247,8 @@ def __init__( (is_supported_union, self._gen_attrs_union_structure, True), ( lambda t: is_union_type(t) and t in self._union_struct_registry, - self._structure_union, + self._union_struct_registry.__getitem__, + True, ), (is_optional, self._structure_optional), (has, self._structure_attrs), @@ -266,9 +268,6 @@ def __init__( self._dict_factory = dict_factory - # Unions are instances now, not classes. We use different registries. - self._union_struct_registry: dict[Any, Callable[[Any, type[T]], T]] = {} - self._unstruct_copy_skip = self._unstructure_func.get_num_fns() self._struct_copy_skip = self._structure_func.get_num_fns() @@ -330,7 +329,7 @@ def register_unstructure_hook( return None def register_unstructure_hook_func( - self, check_func: Callable[[Any], bool], func: UnstructureHook + self, check_func: Predicate, func: UnstructureHook ) -> None: """Register a class-to-primitive converter function for a class, using a function to check if it's a match. @@ -339,25 +338,25 @@ def register_unstructure_hook_func( @overload def register_unstructure_hook_factory( - self, predicate: Callable[[Any], bool] + self, predicate: Predicate ) -> Callable[[UnstructureHookFactory], UnstructureHookFactory]: ... @overload def register_unstructure_hook_factory( - self, predicate: Callable[[Any], bool] + self, predicate: Predicate ) -> Callable[[ExtendedUnstructureHookFactory], ExtendedUnstructureHookFactory]: ... @overload def register_unstructure_hook_factory( - self, predicate: Callable[[Any], bool], factory: UnstructureHookFactory + self, predicate: Predicate, factory: UnstructureHookFactory ) -> UnstructureHookFactory: ... @overload def register_unstructure_hook_factory( - self, predicate: Callable[[Any], bool], factory: ExtendedUnstructureHookFactory + self, predicate: Predicate, factory: ExtendedUnstructureHookFactory ) -> ExtendedUnstructureHookFactory: ... @@ -473,7 +472,7 @@ def register_structure_hook( self._structure_func.register_cls_list([(cl, func)]) def register_structure_hook_func( - self, check_func: Callable[[type[T]], bool], func: StructureHook + self, check_func: Predicate, func: StructureHook ) -> None: """Register a class-to-primitive converter function for a class, using a function to check if it's a match. @@ -482,25 +481,25 @@ def register_structure_hook_func( @overload def register_structure_hook_factory( - self, predicate: Callable[[Any, bool]] + self, predicate: Predicate ) -> Callable[[StructureHookFactory, StructureHookFactory]]: ... @overload def register_structure_hook_factory( - self, predicate: Callable[[Any, bool]] + self, predicate: Predicate ) -> Callable[[ExtendedStructureHookFactory, ExtendedStructureHookFactory]]: ... @overload def register_structure_hook_factory( - self, predicate: Callable[[Any], bool], factory: StructureHookFactory + self, predicate: Predicate, factory: StructureHookFactory ) -> StructureHookFactory: ... @overload def register_structure_hook_factory( - self, predicate: Callable[[Any], bool], factory: ExtendedStructureHookFactory + self, predicate: Predicate, factory: ExtendedStructureHookFactory ) -> ExtendedStructureHookFactory: ... @@ -903,11 +902,6 @@ def _structure_optional(self, obj, union): # We can't actually have a Union of a Union, so this is safe. return self._structure_func.dispatch(other)(obj, other) - def _structure_union(self, obj, union): - """Deal with structuring a union.""" - handler = self._union_struct_registry[union] - return handler(obj, union) - def _structure_tuple(self, obj: Any, tup: type[T]) -> T: """Deal with structuring into a tuple.""" tup_params = None if tup in (Tuple, tuple) else tup.__args__ diff --git a/src/cattrs/dispatch.py b/src/cattrs/dispatch.py index 792b613f..f82ae878 100644 --- a/src/cattrs/dispatch.py +++ b/src/cattrs/dispatch.py @@ -6,12 +6,11 @@ from attrs import Factory, define from ._compat import TypeAlias +from .fns import Predicate if TYPE_CHECKING: from .converters import BaseConverter -T = TypeVar("T") - TargetType: TypeAlias = Any UnstructuredValue: TypeAlias = Any StructuredValue: TypeAlias = Any @@ -46,12 +45,12 @@ class FunctionDispatch: _converter: BaseConverter _handler_pairs: list[ - tuple[Callable[[Any], bool], Callable[[Any, Any], Any], bool, bool] + tuple[Predicate, Callable[[Any, Any], Any], bool, bool] ] = Factory(list) def register( self, - predicate: Callable[[Any], bool], + predicate: Predicate, func: Callable[..., Any], is_generator=False, takes_converter=False, @@ -148,13 +147,9 @@ def register_cls_list(self, cls_and_handler, direct: bool = False) -> None: def register_func_list( self, pred_and_handler: list[ - tuple[Callable[[Any], bool], Any] - | tuple[Callable[[Any], bool], Any, bool] - | tuple[ - Callable[[Any], bool], - Callable[[Any, BaseConverter], Any], - Literal["extended"], - ] + tuple[Predicate, Any] + | tuple[Predicate, Any, bool] + | tuple[Predicate, Callable[[Any, BaseConverter], Any], Literal["extended"]] ], ): """ diff --git a/src/cattrs/fns.py b/src/cattrs/fns.py index 43d0ab0d..7d3db677 100644 --- a/src/cattrs/fns.py +++ b/src/cattrs/fns.py @@ -1,10 +1,14 @@ """Useful internal functions.""" -from typing import NoReturn, Type, TypeVar +from typing import Any, Callable, NoReturn, Type, TypeVar +from ._compat import TypeAlias from .errors import StructureHandlerNotFoundError T = TypeVar("T") +Predicate: TypeAlias = Callable[[Any], bool] +"""A predicate function determines if a type can be handled.""" + def identity(obj: T) -> T: """The identity function.""" diff --git a/src/cattrs/gen/typeddicts.py b/src/cattrs/gen/typeddicts.py index c8a6e619..13decdaa 100644 --- a/src/cattrs/gen/typeddicts.py +++ b/src/cattrs/gen/typeddicts.py @@ -565,11 +565,9 @@ def _required_keys(cls: type) -> set[str]: # gathering required keys. *sigh* own_annotations = cls.__dict__.get("__annotations__", {}) required_keys = set() - for base in cls.__mro__[1:]: - if base in (object, dict): - # These have no required keys for sure. - continue - required_keys |= _required_keys(base) + # On 3.8 - 3.10, typing.TypedDict doesn't put typeddict superclasses + # in the MRO, therefore we cannot handle non-required keys properly + # in some situations. Oh well. for key in getattr(cls, "__required_keys__", []): annotation_type = own_annotations[key] annotation_origin = get_origin(annotation_type) @@ -597,13 +595,7 @@ def _required_keys(cls: type) -> set[str]: own_annotations = cls.__dict__.get("__annotations__", {}) required_keys = set() - superclass_keys = set() - for base in cls.__mro__[1:]: - required_keys |= _required_keys(base) - superclass_keys |= base.__dict__.get("__annotations__", {}).keys() for key in own_annotations: - if key in superclass_keys: - continue annotation_type = own_annotations[key] if is_annotated(annotation_type): diff --git a/tests/_compat.py b/tests/_compat.py index 1636df0d..dba215bd 100644 --- a/tests/_compat.py +++ b/tests/_compat.py @@ -1,7 +1,9 @@ import sys is_py38 = sys.version_info[:2] == (3, 8) +is_py39 = sys.version_info[:2] == (3, 9) is_py39_plus = sys.version_info >= (3, 9) +is_py310 = sys.version_info[:2] == (3, 10) is_py310_plus = sys.version_info >= (3, 10) is_py311_plus = sys.version_info >= (3, 11) is_py312_plus = sys.version_info >= (3, 12) diff --git a/tests/test_typeddicts.py b/tests/test_typeddicts.py index de5d9b59..82f1a3c4 100644 --- a/tests/test_typeddicts.py +++ b/tests/test_typeddicts.py @@ -7,7 +7,7 @@ from hypothesis import assume, given from hypothesis.strategies import booleans from pytest import raises -from typing_extensions import NotRequired +from typing_extensions import NotRequired, Required from cattrs import BaseConverter, Converter from cattrs._compat import ExtensionsTypedDict, get_notrequired_base, is_generic @@ -24,7 +24,7 @@ make_dict_unstructure_fn, ) -from ._compat import is_py38, is_py311_plus +from ._compat import is_py38, is_py39, is_py310, is_py311_plus from .typeddicts import ( generic_typeddicts, simple_typeddicts, @@ -263,6 +263,28 @@ def test_required( assert restructured == instance +@pytest.mark.skipif(is_py39 or is_py310, reason="Sigh") +def test_required_keys() -> None: + """We don't support the full gamut of functionality on 3.8. + + When using `typing.TypedDict` we have only partial functionality; + this test tests only a subset of this. + """ + c = mk_converter() + + class Base(TypedDict, total=False): + a: Required[datetime] + + class Sub(Base): + b: int + + fn = make_dict_unstructure_fn(Sub, c) + + with raises(KeyError): + # This needs to raise since 'a' is missing, and it's Required. + fn({"b": 1}) + + @given(simple_typeddicts(min_attrs=1, total=True), booleans()) def test_omit(cls_and_instance: Tuple[type, Dict], detailed_validation: bool) -> None: """`override(omit=True)` works.""" diff --git a/tests/typeddicts.py b/tests/typeddicts.py index 048a5ae2..6e00f07b 100644 --- a/tests/typeddicts.py +++ b/tests/typeddicts.py @@ -180,7 +180,7 @@ def simple_typeddicts( note( "\n".join( [ - "class HypTypedDict(TypedDict):", + f"class HypTypedDict(TypedDict{'' if total else ', total=False'}):", *[f" {n}: {a}" for n, a in attrs_dict.items()], ] )