From 6ff74e7baf4ee5a328925a2c473d9dbb3972fe52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kacper=20Kali=C5=84ski?= Date: Thu, 12 Dec 2024 11:50:47 +0100 Subject: [PATCH] Allow TypedDict in State --- pyproject.toml | 2 +- src/haiway/state/attributes.py | 461 +++++++++++++++++++-------------- src/haiway/state/structure.py | 5 +- src/haiway/state/validation.py | 59 ++++- tests/test_state.py | 24 +- 5 files changed, 348 insertions(+), 203 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fc2d979..8b1fb32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "haiway" description = "Framework for dependency injection and state management within structured concurrency model." -version = "0.6.4" +version = "0.6.5" readme = "README.md" maintainers = [ { name = "Kacper KaliƄski", email = "kacper.kalinski@miquido.com" }, diff --git a/src/haiway/state/attributes.py b/src/haiway/state/attributes.py index 555911f..75e59af 100644 --- a/src/haiway/state/attributes.py +++ b/src/haiway/state/attributes.py @@ -1,7 +1,6 @@ -import sys import types import typing -from collections.abc import Mapping +from collections.abc import Mapping, MutableMapping, Sequence from types import NoneType, UnionType from typing import ( Any, @@ -9,13 +8,17 @@ ForwardRef, Generic, Literal, - TypeAliasType, + Self, TypeVar, get_args, get_origin, get_type_hints, + is_typeddict, ) +from haiway import types as haiway_types +from haiway.types import MISSING + __all__ = [ "AttributeAnnotation", "attribute_annotations", @@ -28,10 +31,23 @@ def __init__( self, *, origin: Any, - arguments: list[Any], + arguments: Sequence[Any] | None = None, + required: bool = True, + extra: Mapping[str, Any] | None = None, ) -> None: self.origin: Any = origin - self.arguments: list[Any] = arguments + self.arguments: Sequence[Any] = arguments or () + self.required: bool = required + self.extra: Mapping[str, Any] = extra or {} + + def update_required( + self, + required: bool, + /, + ) -> Self: + self.required = self.required and required + + return self def __eq__( self, @@ -52,103 +68,154 @@ def __str__(self) -> str: def attribute_annotations( cls: type[Any], /, - type_parameters: dict[str, Any] | None = None, -) -> dict[str, AttributeAnnotation]: + type_parameters: Mapping[str, Any] | None = None, +) -> Mapping[str, AttributeAnnotation]: type_parameters = type_parameters or {} self_annotation = AttributeAnnotation( origin=cls, arguments=[], # ignore self arguments here, State will have them resolved at this stage ) - localns: dict[str, Any] = {cls.__name__: cls} - recursion_guard: dict[Any, AttributeAnnotation] = {cls: self_annotation} attributes: dict[str, AttributeAnnotation] = {} - for key, annotation in get_type_hints(cls, localns=localns).items(): + for key, annotation in get_type_hints(cls, localns={cls.__name__: cls}).items(): # do not include ClassVars, private or dunder items if ((get_origin(annotation) or annotation) is ClassVar) or key.startswith("_"): continue attributes[key] = resolve_attribute_annotation( annotation, - self_annotation=self_annotation, type_parameters=type_parameters, - module=cls.__module__, - localns=localns, - recursion_guard=recursion_guard, + module=getattr(annotation, "__module__", cls.__module__), + localns={cls.__name__: cls}, + recursion_guard={"Self": self_annotation, cls.__name__: self_annotation}, ) return attributes -def resolve_attribute_annotation( # noqa: C901, PLR0911, PLR0912, PLR0913 +def resolve_attribute_annotation( # noqa: C901, PLR0912, PLR0915 annotation: Any, /, - self_annotation: AttributeAnnotation | None, - type_parameters: dict[str, Any], + type_parameters: Mapping[str, Any], module: str, - localns: dict[str, Any], - recursion_guard: Mapping[Any, AttributeAnnotation], # TODO: verify recursion! + localns: Mapping[str, Any], + recursion_guard: MutableMapping[str, AttributeAnnotation], # TODO: verify recursion! ) -> AttributeAnnotation: + resolved_attribute = AttributeAnnotation(origin=MISSING) + if recursion_name := getattr( + annotation, + "__qualname__", + getattr( + annotation, + "__name__", + None, + ), + ): + if recursive := recursion_guard.get(recursion_name): + return recursive + + else: + recursion_guard[recursion_name] = resolved_attribute + # resolve annotation directly if able match annotation: # None case types.NoneType | types.NoneType(): - return AttributeAnnotation( - origin=NoneType, - arguments=[], - ) + resolved_attribute.origin = NoneType # forward reference through string case str() as forward_ref: - return resolve_attribute_annotation( - ForwardRef(forward_ref, module=module)._evaluate( - globalns=None, - localns=localns, - recursive_guard=frozenset(), + evaluated: Any = ForwardRef(forward_ref, module=module)._evaluate( + globalns=None, + localns=localns, + recursive_guard=frozenset(), + ) + + if recursion_name := getattr( + evaluated, + "__qualname__", + getattr( + evaluated, + "__name__", + None, ), - self_annotation=self_annotation, + ): + if recursive := recursion_guard.get(recursion_name): + return recursive + + else: + recursion_guard[recursion_name] = resolved_attribute + + resolved: AttributeAnnotation = resolve_attribute_annotation( + evaluated, type_parameters=type_parameters, - module=module, - localns=localns, - recursion_guard=recursion_guard, # we might need to update it somehow? + module=evaluated.__module__ or module, + localns={**localns, getattr(evaluated, "__name__", str(evaluated)): evaluated}, + recursion_guard=recursion_guard, ) + resolved_attribute.origin = resolved.origin + resolved_attribute.arguments = resolved.arguments + resolved_attribute.extra = resolved.extra + resolved_attribute.required = resolved.required # forward reference directly case typing.ForwardRef() as reference: - return resolve_attribute_annotation( - reference._evaluate( - globalns=None, - localns=localns, - recursive_guard=frozenset(), + del recursion_guard[annotation.__qualname__] # clean assotiation + evaluated: Any = reference._evaluate( + globalns=None, + localns=localns, + recursive_guard=frozenset(), + ) + + if recursion_name := getattr( + evaluated, + "__qualname__", + getattr( + evaluated, + "__name__", + None, ), - self_annotation=self_annotation, + ): + if recursive := recursion_guard.get(recursion_name): + return recursive + + else: + recursion_guard[recursion_name] = resolved_attribute + + resolved: AttributeAnnotation = resolve_attribute_annotation( + evaluated, type_parameters=type_parameters, - module=module, - localns=localns, - recursion_guard=recursion_guard, # we might need to update it somehow? + module=evaluated.__module__ or module, + localns={**localns, getattr(evaluated, "__name__", str(evaluated)): evaluated}, + recursion_guard=recursion_guard, ) + resolved_attribute.origin = resolved.origin + resolved_attribute.arguments = resolved.arguments + resolved_attribute.extra = resolved.extra + resolved_attribute.required = resolved.required + # generic alias aka parametrized type case types.GenericAlias() as generic_alias: + del recursion_guard[annotation.__qualname__] # clean assotiation match get_origin(generic_alias): # check for an alias with parameters case typing.TypeAliasType() as alias: # pyright: ignore[reportUnnecessaryComparison] - type_alias: AttributeAnnotation = AttributeAnnotation( - origin=TypeAliasType, - arguments=[], + resolved_attribute.origin = NotImplementedError( + f"Unresolved alias annotation: {annotation}" ) resolved: AttributeAnnotation = resolve_attribute_annotation( alias.__value__, - self_annotation=None, type_parameters=type_parameters, module=module, localns=localns, recursion_guard=recursion_guard, ) - type_alias.origin = resolved.origin - type_alias.arguments = resolved.arguments - return type_alias + resolved_attribute.origin = resolved.origin + resolved_attribute.arguments = resolved.arguments + resolved_attribute.extra = resolved.extra + resolved_attribute.required = resolved.required # check if we can resolve it as generic case parametrized if issubclass(parametrized, Generic): @@ -167,200 +234,210 @@ def resolve_attribute_annotation( # noqa: C901, PLR0911, PLR0912, PLR0913 match parametrized_type: # verify if we got any specific type or generic alias again case types.GenericAlias(): - return AttributeAnnotation( - origin=parametrized, - arguments=[ - resolve_attribute_annotation( - argument, - self_annotation=self_annotation, - type_parameters=type_parameters, - module=module, - localns=localns, - recursion_guard=recursion_guard, - ) - for argument in get_args(generic_alias) - ], - ) + resolved_attribute.origin = parametrized + resolved_attribute.arguments = [ + resolve_attribute_annotation( + argument, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ) + for argument in get_args(generic_alias) + ] # use resolved type if it is not an alias again case _: - return AttributeAnnotation( - origin=parametrized_type, - arguments=[], - ) + resolved_attribute.origin = parametrized_type # anything else - try to resolve a concrete type or use as is case origin: - return AttributeAnnotation( - origin=origin, - arguments=[ - resolve_attribute_annotation( - argument, - self_annotation=self_annotation, - type_parameters=type_parameters, - module=module, - localns=localns, - recursion_guard=recursion_guard, - ) - for argument in get_args(generic_alias) - ], - ) + resolved_attribute.origin = origin + resolved_attribute.arguments = [ + resolve_attribute_annotation( + argument, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ) + for argument in get_args(generic_alias) + ] # type alias case typing.TypeAliasType() as alias: - type_alias: AttributeAnnotation = AttributeAnnotation( - origin=TypeAliasType, - arguments=[], - ) + del recursion_guard[annotation.__qualname__] # clean assotiation resolved: AttributeAnnotation = resolve_attribute_annotation( alias.__value__, - self_annotation=None, type_parameters=type_parameters, module=module, - localns=localns, + localns={**localns, alias.__name__: alias}, recursion_guard=recursion_guard, ) - type_alias.origin = resolved.origin - type_alias.arguments = resolved.arguments - return type_alias + resolved_attribute.origin = resolved.origin + resolved_attribute.arguments = resolved.arguments + resolved_attribute.extra = resolved.extra + resolved_attribute.required = resolved.required # type parameter case typing.TypeVar(): - return resolve_attribute_annotation( + del recursion_guard[annotation.__qualname__] # clean assotiation + resolved: AttributeAnnotation = resolve_attribute_annotation( # try to resolve it from current parameters if able type_parameters.get( annotation.__name__, # use bound as default or Any otherwise annotation.__bound__ or Any, ), - self_annotation=None, type_parameters=type_parameters, module=module, localns=localns, recursion_guard=recursion_guard, ) + resolved_attribute.origin = resolved.origin + resolved_attribute.arguments = resolved.arguments + resolved_attribute.extra = resolved.extra + resolved_attribute.required = resolved.required case typing.ParamSpec(): - sys.stderr.write( - "ParamSpec is not supported for attribute annotations," - " ignoring with Any type - it might incorrectly validate types\n" - ) - return AttributeAnnotation( - origin=Any, - arguments=[], - ) + raise NotImplementedError(f"Unresolved ParamSpec annotation: {annotation}") case typing.TypeVarTuple(): - sys.stderr.write( - "TypeVarTuple is not supported for attribute annotations," - " ignoring with Any type - it might incorrectly validate types\n" - ) - return AttributeAnnotation( - origin=Any, - arguments=[], - ) - - case _: - pass # proceed to resolving based on origin - - # resolve based on origin if any - match get_origin(annotation) or annotation: - case types.UnionType | typing.Union: - return AttributeAnnotation( - origin=UnionType, # pyright: ignore[reportArgumentType] - arguments=[ - recursion_guard.get( - argument, + raise NotImplementedError(f"Unresolved TypeVarTuple annotation: {annotation}") + + case _: # proceed to resolving based on origin + match get_origin(annotation) or annotation: + case types.UnionType | typing.Union: + del recursion_guard[annotation.__qualname__] # clean assotiation + resolved_attribute.origin = UnionType # pyright: ignore[reportArgumentType] + resolved_attribute.arguments = [ resolve_attribute_annotation( argument, - self_annotation=self_annotation, type_parameters=type_parameters, module=module, localns=localns, recursion_guard=recursion_guard, - ), - ) - for argument in get_args(annotation) - ], - ) + ) + for argument in get_args(annotation) + ] - case typing.Callable: # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] - return AttributeAnnotation( - origin=typing.Callable, - arguments=[ - resolve_attribute_annotation( - argument, - self_annotation=self_annotation, + case typing.Callable: # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] + del recursion_guard[annotation.__qualname__] # clean assotiation + resolved_attribute.origin = typing.Callable + resolved_attribute.arguments = [ + resolve_attribute_annotation( + argument, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ) + for argument in get_args(annotation) + ] + + case typing.Self: # pyright: ignore[reportUnknownMemberType] + if self_annotation := recursion_guard.get("Self"): + resolved_attribute.origin = self_annotation.origin + resolved_attribute.arguments = self_annotation.arguments + resolved_attribute.extra = self_annotation.extra + # skipping requirement as it might be different + + else: + raise NotImplementedError(f"Unresolved Self annotation: {annotation}") + + # unwrap from irrelevant type wrappers + case typing.Annotated | typing.Final | typing.Required: + del recursion_guard[annotation.__qualname__] # clean assotiation + resolved: AttributeAnnotation = resolve_attribute_annotation( + get_args(annotation)[0], type_parameters=type_parameters, module=module, localns=localns, recursion_guard=recursion_guard, ) - for argument in get_args(annotation) - ], - ) + resolved_attribute.origin = resolved.origin + resolved_attribute.arguments = resolved.arguments + resolved_attribute.extra = resolved.extra + resolved_attribute.required = resolved.required - case typing.Self: # pyright: ignore[reportUnknownMemberType] - if not self_annotation: - sys.stderr.write( - "Unresolved Self attribute annotation," - " ignoring with Any type - it might incorrectly validate types\n" - ) - return AttributeAnnotation( - origin=Any, - arguments=[], - ) - - return self_annotation - - # unwrap from irrelevant type wrappers - case typing.Annotated | typing.Final | typing.Required | typing.NotRequired: - return resolve_attribute_annotation( - get_args(annotation)[0], - self_annotation=self_annotation, - type_parameters=type_parameters, - module=module, - localns=localns, - recursion_guard=recursion_guard, - ) - - case typing.Optional: # optional is a Union[Value, None] - return AttributeAnnotation( - origin=UnionType, # pyright: ignore[reportArgumentType] - arguments=[ - resolve_attribute_annotation( + case typing.NotRequired: + del recursion_guard[annotation.__qualname__] # clean assotiation + resolved: AttributeAnnotation = resolve_attribute_annotation( get_args(annotation)[0], - self_annotation=self_annotation, type_parameters=type_parameters, module=module, localns=localns, recursion_guard=recursion_guard, - ), - AttributeAnnotation( - origin=NoneType, - arguments=[], - ), - ], - ) + ) + resolved_attribute.origin = resolved.origin + resolved_attribute.arguments = resolved.arguments + resolved_attribute.extra = resolved.extra + resolved_attribute.required = False + + case typing.Optional: # optional is a Union[Value, None] + del recursion_guard[annotation.__qualname__] # clean assotiation + resolved_attribute.origin = UnionType # pyright: ignore[reportArgumentType] + resolved_attribute.arguments = [ + resolve_attribute_annotation( + get_args(annotation)[0], + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ), + AttributeAnnotation( + origin=NoneType, + ), + ] - case typing.Literal: - return AttributeAnnotation( - origin=Literal, - arguments=list(get_args(annotation)), - ) + case typing.Literal: + del recursion_guard[annotation.__qualname__] # clean assotiation + resolved_attribute.origin = Literal + resolved_attribute.arguments = get_args(annotation) - case other: # finally use whatever there was - return AttributeAnnotation( - origin=other, - arguments=[ - resolve_attribute_annotation( - argument, - self_annotation=self_annotation, - type_parameters=type_parameters, - module=module, - localns=localns, - recursion_guard=recursion_guard, - ) - for argument in get_args(other) - ], - ) + case typeddict if is_typeddict(typeddict): + resolved_attribute.origin = typeddict + resolved_attribute.arguments = [ + resolve_attribute_annotation( + argument, + type_parameters=type_parameters, + module=module, + localns=localns, + recursion_guard=recursion_guard, + ) + for argument in get_args(typeddict) + ] + resolved_attribute.extra = { + key: resolve_attribute_annotation( + annotation, + type_parameters=type_parameters, + module=getattr(typeddict, "__module__", module), + localns=localns, + recursion_guard=recursion_guard, + ).update_required(key in getattr(typeddict, "__required_keys__", {})) + for key, annotation in typeddict.__annotations__.items() + } + + case haiway_types.Missing: + # special case - attributes marked as missing are not required + # Missing does not work properly within TypedDict though + resolved_attribute.origin = haiway_types.Missing + resolved_attribute.required = False + + case other: # finally use whatever there was + resolved_attribute.origin = other + resolved_attribute.arguments = [ + resolve_attribute_annotation( + argument, + type_parameters=type_parameters, + module=getattr(other, "__module__", module), + localns=localns, + recursion_guard=recursion_guard, + ) + for argument in get_args(other) + ] + + if resolved_attribute.origin is MISSING: + raise NotImplementedError(f"Unresolved annotation: {annotation}") + + return resolved_attribute diff --git a/src/haiway/state/structure.py b/src/haiway/state/structure.py index 4c5ef20..a91885c 100644 --- a/src/haiway/state/structure.py +++ b/src/haiway/state/structure.py @@ -78,9 +78,10 @@ def __new__( if ((get_origin(annotation) or annotation) is ClassVar) or key.startswith("__"): continue + default: Any = getattr(state_type, key, MISSING) attributes[key] = StateAttribute( - annotation=annotation, - default=getattr(state_type, key, MISSING), + annotation=annotation.update_required(default is MISSING), + default=default, validator=attribute_validator(annotation), ) diff --git a/src/haiway/state/validation.py b/src/haiway/state/validation.py index f757920..64e4eb2 100644 --- a/src/haiway/state/validation.py +++ b/src/haiway/state/validation.py @@ -3,8 +3,8 @@ from enum import Enum from pathlib import Path from re import Pattern -from types import MappingProxyType, NoneType, UnionType -from typing import Any, Literal, Protocol, Union +from types import EllipsisType, MappingProxyType, NoneType, UnionType +from typing import Any, Literal, Protocol, Union, is_typeddict from typing import Mapping as MappingType # noqa: UP035 from typing import Sequence as SequenceType # noqa: UP035 from typing import Sequence as SetType # noqa: UP035 @@ -22,12 +22,18 @@ def attribute_validator( annotation: AttributeAnnotation, /, ) -> Callable[[Any], Any]: + if isinstance(annotation.origin, NotImplementedError | RuntimeError): + raise annotation.origin # raise an error if origin was not properly resolved + if validator := VALIDATORS.get(annotation.origin): return validator(annotation) elif hasattr(annotation.origin, "__IMMUTABLE__"): return _prepare_validator_of_type(annotation) + elif is_typeddict(annotation.origin): + return _prepare_validator_of_typed_dict(annotation) + elif issubclass(annotation.origin, Protocol): return _prepare_validator_of_type(annotation) @@ -86,7 +92,7 @@ def _prepare_validator_of_literal( annotation: AttributeAnnotation, /, ) -> Callable[[Any], Any]: - elements: list[Any] = annotation.arguments + elements: Sequence[Any] = annotation.arguments formatted_type: str = str(annotation) def validator( @@ -108,7 +114,7 @@ def _prepare_validator_of_type( validated_type: type[Any] = annotation.origin formatted_type: str = str(annotation) - def type_validator( + def validator( value: Any, ) -> Any: match value: @@ -118,7 +124,7 @@ def type_validator( case _: raise TypeError(f"'{value}' is not matching expected type of '{formatted_type}'") - return type_validator + return validator def _prepare_validator_of_set( @@ -187,7 +193,10 @@ def _prepare_validator_of_tuple( annotation: AttributeAnnotation, /, ) -> Callable[[Any], Any]: - if annotation.arguments[-1].origin == Ellipsis: + if ( + annotation.arguments[-1].origin == Ellipsis + or annotation.arguments[-1].origin == EllipsisType + ): element_validator: Callable[[Any], Any] = attribute_validator(annotation.arguments[0]) formatted_type: str = str(annotation) @@ -280,6 +289,44 @@ def validator( return validator +def _prepare_validator_of_typed_dict( + annotation: AttributeAnnotation, + /, +) -> Callable[[Any], Any]: + def key_validator( + value: Any, + ) -> Any: + match value: + case value if isinstance(value, str): + return value + + case _: + raise TypeError(f"'{value}' is not matching expected type of 'str'") + + values_validators: dict[str, Callable[[Any], Any]] = { + key: attribute_validator(element) for key, element in annotation.extra.items() + } + formatted_type: str = str(annotation) + + def validator( + value: Any, + ) -> Any: + match value: + case {**elements}: + validated: dict[str, Any] = {} + for key, validator in values_validators.items(): + element: Any = elements.get(key, MISSING) + if element is not MISSING: + validated[key_validator(key)] = validator(element) + + return MappingProxyType(validated) + + case _: + raise TypeError(f"'{value}' is not matching expected type of '{formatted_type}'") + + return validator + + VALIDATORS: Mapping[Any, Callable[[AttributeAnnotation], Callable[[Any], Any]]] = { Any: _prepare_validator_of_any, NoneType: _prepare_validator_of_none, diff --git a/tests/test_state.py b/tests/test_state.py index e09e098..2c0e9f8 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,7 +1,7 @@ from collections.abc import Callable, Sequence, Set from datetime import date, datetime from enum import StrEnum -from typing import Literal, Protocol, Self, runtime_checkable +from typing import Literal, NotRequired, Protocol, Required, Self, TypedDict, runtime_checkable from uuid import UUID, uuid4 from haiway import MISSING, Missing, State, frozenlist @@ -10,7 +10,16 @@ def test_basic_initializes_with_arguments() -> None: class Selection(StrEnum): A = "A" - B = "A" + B = "B" + + class Nes(State): + val: str + + class TypedValues(TypedDict): + val: str + mis: int | Missing + req: Required[Nes] + nreq: NotRequired[bool] @runtime_checkable class Proto(Protocol): @@ -32,6 +41,7 @@ class Basics(State): function: Callable[[], None] proto: Proto selection: Selection + typeddict: TypedValues basic = Basics( uuid=uuid4(), @@ -49,6 +59,11 @@ class Basics(State): function=lambda: None, proto=lambda: None, selection=Selection.A, + typeddict={ + "val": "ok", + "mis": 42, + "req": Nes(val="ok"), + }, ) assert basic.string == "string" assert basic.literal == "A" @@ -60,6 +75,11 @@ class Basics(State): assert basic.optional == "optional" assert basic.none is None assert basic.selection == Selection.A + assert basic.typeddict == TypedValues( + val="ok", + mis=42, + req=Nes(val="ok"), + ) assert callable(basic.function) assert isinstance(basic.proto, Proto)