Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tin/better union hooks #499

Merged
merged 3 commits into from
Feb 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 14 additions & 20 deletions src/cattrs/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
IterableValidationNote,
StructureHandlerNotFoundError,
)
from .fns import identity, raise_error
from .fns import Predicate, identity, raise_error
from .gen import (
AttributeOverride,
DictStructureFn,
Expand Down Expand Up @@ -174,6 +174,7 @@ def __init__(
self._prefer_attrib_converters = prefer_attrib_converters

self.detailed_validation = detailed_validation
self._union_struct_registry: dict[Any, Callable[[Any, type[T]], T]] = {}

# Create a per-instance cache.
if unstruct_strat is UnstructureStrategy.AS_DICT:
Expand Down Expand Up @@ -246,7 +247,8 @@ def __init__(
(is_supported_union, self._gen_attrs_union_structure, True),
(
lambda t: is_union_type(t) and t in self._union_struct_registry,
self._structure_union,
self._union_struct_registry.__getitem__,
True,
),
(is_optional, self._structure_optional),
(has, self._structure_attrs),
Expand All @@ -266,9 +268,6 @@ def __init__(

self._dict_factory = dict_factory

# Unions are instances now, not classes. We use different registries.
self._union_struct_registry: dict[Any, Callable[[Any, type[T]], T]] = {}

self._unstruct_copy_skip = self._unstructure_func.get_num_fns()
self._struct_copy_skip = self._structure_func.get_num_fns()

Expand Down Expand Up @@ -330,7 +329,7 @@ def register_unstructure_hook(
return None

def register_unstructure_hook_func(
self, check_func: Callable[[Any], bool], func: UnstructureHook
self, check_func: Predicate, func: UnstructureHook
) -> None:
"""Register a class-to-primitive converter function for a class, using
a function to check if it's a match.
Expand All @@ -339,25 +338,25 @@ def register_unstructure_hook_func(

@overload
def register_unstructure_hook_factory(
self, predicate: Callable[[Any], bool]
self, predicate: Predicate
) -> Callable[[UnstructureHookFactory], UnstructureHookFactory]:
...

@overload
def register_unstructure_hook_factory(
self, predicate: Callable[[Any], bool]
self, predicate: Predicate
) -> Callable[[ExtendedUnstructureHookFactory], ExtendedUnstructureHookFactory]:
...

@overload
def register_unstructure_hook_factory(
self, predicate: Callable[[Any], bool], factory: UnstructureHookFactory
self, predicate: Predicate, factory: UnstructureHookFactory
) -> UnstructureHookFactory:
...

@overload
def register_unstructure_hook_factory(
self, predicate: Callable[[Any], bool], factory: ExtendedUnstructureHookFactory
self, predicate: Predicate, factory: ExtendedUnstructureHookFactory
) -> ExtendedUnstructureHookFactory:
...

Expand Down Expand Up @@ -473,7 +472,7 @@ def register_structure_hook(
self._structure_func.register_cls_list([(cl, func)])

def register_structure_hook_func(
self, check_func: Callable[[type[T]], bool], func: StructureHook
self, check_func: Predicate, func: StructureHook
) -> None:
"""Register a class-to-primitive converter function for a class, using
a function to check if it's a match.
Expand All @@ -482,25 +481,25 @@ def register_structure_hook_func(

@overload
def register_structure_hook_factory(
self, predicate: Callable[[Any, bool]]
self, predicate: Predicate
) -> Callable[[StructureHookFactory, StructureHookFactory]]:
...

@overload
def register_structure_hook_factory(
self, predicate: Callable[[Any, bool]]
self, predicate: Predicate
) -> Callable[[ExtendedStructureHookFactory, ExtendedStructureHookFactory]]:
...

@overload
def register_structure_hook_factory(
self, predicate: Callable[[Any], bool], factory: StructureHookFactory
self, predicate: Predicate, factory: StructureHookFactory
) -> StructureHookFactory:
...

@overload
def register_structure_hook_factory(
self, predicate: Callable[[Any], bool], factory: ExtendedStructureHookFactory
self, predicate: Predicate, factory: ExtendedStructureHookFactory
) -> ExtendedStructureHookFactory:
...

Expand Down Expand Up @@ -903,11 +902,6 @@ def _structure_optional(self, obj, union):
# We can't actually have a Union of a Union, so this is safe.
return self._structure_func.dispatch(other)(obj, other)

def _structure_union(self, obj, union):
"""Deal with structuring a union."""
handler = self._union_struct_registry[union]
return handler(obj, union)

def _structure_tuple(self, obj: Any, tup: type[T]) -> T:
"""Deal with structuring into a tuple."""
tup_params = None if tup in (Tuple, tuple) else tup.__args__
Expand Down
17 changes: 6 additions & 11 deletions src/cattrs/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
from attrs import Factory, define

from ._compat import TypeAlias
from .fns import Predicate

if TYPE_CHECKING:
from .converters import BaseConverter

T = TypeVar("T")

TargetType: TypeAlias = Any
UnstructuredValue: TypeAlias = Any
StructuredValue: TypeAlias = Any
Expand Down Expand Up @@ -46,12 +45,12 @@ class FunctionDispatch:

_converter: BaseConverter
_handler_pairs: list[
tuple[Callable[[Any], bool], Callable[[Any, Any], Any], bool, bool]
tuple[Predicate, Callable[[Any, Any], Any], bool, bool]
] = Factory(list)

def register(
self,
predicate: Callable[[Any], bool],
predicate: Predicate,
func: Callable[..., Any],
is_generator=False,
takes_converter=False,
Expand Down Expand Up @@ -148,13 +147,9 @@ def register_cls_list(self, cls_and_handler, direct: bool = False) -> None:
def register_func_list(
self,
pred_and_handler: list[
tuple[Callable[[Any], bool], Any]
| tuple[Callable[[Any], bool], Any, bool]
| tuple[
Callable[[Any], bool],
Callable[[Any, BaseConverter], Any],
Literal["extended"],
]
tuple[Predicate, Any]
| tuple[Predicate, Any, bool]
| tuple[Predicate, Callable[[Any, BaseConverter], Any], Literal["extended"]]
],
):
"""
Expand Down
6 changes: 5 additions & 1 deletion src/cattrs/fns.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Useful internal functions."""
from typing import NoReturn, Type, TypeVar
from typing import Any, Callable, NoReturn, Type, TypeVar

from ._compat import TypeAlias
from .errors import StructureHandlerNotFoundError

T = TypeVar("T")

Predicate: TypeAlias = Callable[[Any], bool]
"""A predicate function determines if a type can be handled."""


def identity(obj: T) -> T:
"""The identity function."""
Expand Down
14 changes: 3 additions & 11 deletions src/cattrs/gen/typeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,11 +565,9 @@ def _required_keys(cls: type) -> set[str]:
# gathering required keys. *sigh*
own_annotations = cls.__dict__.get("__annotations__", {})
required_keys = set()
for base in cls.__mro__[1:]:
if base in (object, dict):
# These have no required keys for sure.
continue
required_keys |= _required_keys(base)
# On 3.8 - 3.10, typing.TypedDict doesn't put typeddict superclasses
# in the MRO, therefore we cannot handle non-required keys properly
# in some situations. Oh well.
for key in getattr(cls, "__required_keys__", []):
annotation_type = own_annotations[key]
annotation_origin = get_origin(annotation_type)
Expand Down Expand Up @@ -597,13 +595,7 @@ def _required_keys(cls: type) -> set[str]:

own_annotations = cls.__dict__.get("__annotations__", {})
required_keys = set()
superclass_keys = set()
for base in cls.__mro__[1:]:
required_keys |= _required_keys(base)
superclass_keys |= base.__dict__.get("__annotations__", {}).keys()
for key in own_annotations:
if key in superclass_keys:
continue
annotation_type = own_annotations[key]

if is_annotated(annotation_type):
Expand Down
2 changes: 2 additions & 0 deletions tests/_compat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import sys

is_py38 = sys.version_info[:2] == (3, 8)
is_py39 = sys.version_info[:2] == (3, 9)
is_py39_plus = sys.version_info >= (3, 9)
is_py310 = sys.version_info[:2] == (3, 10)
is_py310_plus = sys.version_info >= (3, 10)
is_py311_plus = sys.version_info >= (3, 11)
is_py312_plus = sys.version_info >= (3, 12)
Expand Down
26 changes: 24 additions & 2 deletions tests/test_typeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from hypothesis import assume, given
from hypothesis.strategies import booleans
from pytest import raises
from typing_extensions import NotRequired
from typing_extensions import NotRequired, Required

from cattrs import BaseConverter, Converter
from cattrs._compat import ExtensionsTypedDict, get_notrequired_base, is_generic
Expand All @@ -24,7 +24,7 @@
make_dict_unstructure_fn,
)

from ._compat import is_py38, is_py311_plus
from ._compat import is_py38, is_py39, is_py310, is_py311_plus
from .typeddicts import (
generic_typeddicts,
simple_typeddicts,
Expand Down Expand Up @@ -263,6 +263,28 @@ def test_required(
assert restructured == instance


@pytest.mark.skipif(is_py39 or is_py310, reason="Sigh")
def test_required_keys() -> None:
"""We don't support the full gamut of functionality on 3.8.

When using `typing.TypedDict` we have only partial functionality;
this test tests only a subset of this.
"""
c = mk_converter()

class Base(TypedDict, total=False):
a: Required[datetime]

class Sub(Base):
b: int

fn = make_dict_unstructure_fn(Sub, c)

with raises(KeyError):
# This needs to raise since 'a' is missing, and it's Required.
fn({"b": 1})


@given(simple_typeddicts(min_attrs=1, total=True), booleans())
def test_omit(cls_and_instance: Tuple[type, Dict], detailed_validation: bool) -> None:
"""`override(omit=True)` works."""
Expand Down
2 changes: 1 addition & 1 deletion tests/typeddicts.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def simple_typeddicts(
note(
"\n".join(
[
"class HypTypedDict(TypedDict):",
f"class HypTypedDict(TypedDict{'' if total else ', total=False'}):",
*[f" {n}: {a}" for n, a in attrs_dict.items()],
]
)
Expand Down
Loading