diff --git a/pyproject.toml b/pyproject.toml index fc2d979..8b1fb32 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta" [project] name = "haiway" description = "Framework for dependency injection and state management within structured concurrency model." -version = "0.6.4" +version = "0.6.5" readme = "README.md" maintainers = [ { name = "Kacper KaliƄski", email = "kacper.kalinski@miquido.com" }, diff --git a/src/haiway/state/attributes.py b/src/haiway/state/attributes.py index 555911f..c740254 100644 --- a/src/haiway/state/attributes.py +++ b/src/haiway/state/attributes.py @@ -1,21 +1,28 @@ -import sys import types import typing -from collections.abc import Mapping -from types import NoneType, UnionType +from collections.abc import Callable, Mapping, MutableMapping, Sequence +from types import GenericAlias, NoneType, UnionType from typing import ( Any, ClassVar, ForwardRef, Generic, Literal, + ParamSpec, + Self, TypeAliasType, TypeVar, + TypeVarTuple, + _GenericAlias, # pyright: ignore get_args, get_origin, get_type_hints, + is_typeddict, ) +from haiway import types as haiway_types +from haiway.types import MISSING + __all__ = [ "AttributeAnnotation", "attribute_annotations", @@ -27,11 +34,26 @@ class AttributeAnnotation: def __init__( self, *, + module: str, origin: Any, - arguments: list[Any], + arguments: Sequence[Any] | None = None, + required: bool = True, + extra: Mapping[str, Any] | None = None, ) -> None: + self.module: str = module self.origin: Any = origin - self.arguments: list[Any] = arguments + self.arguments: Sequence[Any] = arguments or () + self.required: bool = required + self.extra: Mapping[str, Any] = extra or {} + + def update_required( + self, + required: bool, + /, + ) -> Self: + self.required = self.required and required + + return self def __eq__( self, @@ -44,323 +66,645 @@ def __eq__( ) def __str__(self) -> str: - return f"{getattr(self.origin, "__name__", str(self.origin))}" + ( + return f"{self.module}.{getattr(self.origin, "__name__", str(self.origin))}" + ( ("[" + ", ".join(str(arg) for arg in self.arguments) + "]") if self.arguments else "" ) + # def __copy__(self) -> Self: + # return self.__class__( + # module=self.module, + # origin=self.origin, + # arguments=self.arguments, + # required=True, # reset required for copy + # extra=self.extra, + # ) + def attribute_annotations( cls: type[Any], /, - type_parameters: dict[str, Any] | None = None, -) -> dict[str, AttributeAnnotation]: + type_parameters: Mapping[str, Any] | None = None, +) -> Mapping[str, AttributeAnnotation]: type_parameters = type_parameters or {} self_annotation = AttributeAnnotation( + module=cls.__module__, origin=cls, arguments=[], # ignore self arguments here, State will have them resolved at this stage ) - localns: dict[str, Any] = {cls.__name__: cls} - recursion_guard: dict[Any, AttributeAnnotation] = {cls: self_annotation} - attributes: dict[str, AttributeAnnotation] = {} - for key, annotation in get_type_hints(cls, localns=localns).items(): + attributes: dict[str, AttributeAnnotation] = {} + for key, annotation in get_type_hints(cls, localns={cls.__name__: cls}).items(): # do not include ClassVars, private or dunder items if ((get_origin(annotation) or annotation) is ClassVar) or key.startswith("_"): continue attributes[key] = resolve_attribute_annotation( annotation, - self_annotation=self_annotation, type_parameters=type_parameters, - module=cls.__module__, - localns=localns, - recursion_guard=recursion_guard, + module=cls.__module__, # getattr(annotation, "__module__", cls.__module__), + self_annotation=self_annotation, + recursion_guard={cls.__qualname__: self_annotation}, ) return attributes -def resolve_attribute_annotation( # noqa: C901, PLR0911, PLR0912, PLR0913 +def _resolve_none( annotation: Any, +) -> AttributeAnnotation: + return AttributeAnnotation( + module=NoneType.__module__, + origin=NoneType, + ) + + +def _resolve_missing( + annotation: Any, +) -> AttributeAnnotation: + # special case - attributes marked as missing are not required + # Missing does not work properly within TypedDict though + return AttributeAnnotation( + module=haiway_types.Missing.__module__, + origin=haiway_types.Missing, + required=False, + ) + + +def _resolve_literal( + annotation: Any, +) -> AttributeAnnotation: + return AttributeAnnotation( + module=Literal.__module__, + origin=Literal, + arguments=get_args(annotation), + ) + + +def _resolve_forward_ref( + annotation: ForwardRef | str, /, - self_annotation: AttributeAnnotation | None, - type_parameters: dict[str, Any], module: str, - localns: dict[str, Any], - recursion_guard: Mapping[Any, AttributeAnnotation], # TODO: verify recursion! + type_parameters: Mapping[str, Any], + self_annotation: AttributeAnnotation | None, + recursion_guard: MutableMapping[str, AttributeAnnotation], ) -> AttributeAnnotation: - # resolve annotation directly if able + forward_ref: ForwardRef match annotation: - # None - case types.NoneType | types.NoneType(): - return AttributeAnnotation( - origin=NoneType, - arguments=[], - ) + case str() as string: + forward_ref = ForwardRef(string, module=module) + + case reference: + forward_ref = reference + + if evaluated := forward_ref._evaluate( + globalns=None, + localns=None, + recursive_guard=frozenset(), + ): + return resolve_attribute_annotation( + evaluated, + type_parameters=type_parameters, + module=module, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) - # forward reference through string - case str() as forward_ref: - return resolve_attribute_annotation( - ForwardRef(forward_ref, module=module)._evaluate( - globalns=None, - localns=localns, - recursive_guard=frozenset(), - ), - self_annotation=self_annotation, - type_parameters=type_parameters, - module=module, - localns=localns, - recursion_guard=recursion_guard, # we might need to update it somehow? - ) + else: + raise RuntimeError(f"Cannot resolve annotation of {annotation}") - # forward reference directly - case typing.ForwardRef() as reference: - return resolve_attribute_annotation( - reference._evaluate( - globalns=None, - localns=localns, - recursive_guard=frozenset(), - ), - self_annotation=self_annotation, + +def _resolve_generic_alias( # noqa: PLR0911 + annotation: GenericAlias, + /, + module: str, + type_parameters: Mapping[str, Any], + self_annotation: AttributeAnnotation | None, + recursion_guard: MutableMapping[str, AttributeAnnotation], +) -> AttributeAnnotation: + match get_origin(annotation): + case TypeAliasType() as origin: # pyright: ignore[reportUnnecessaryComparison] + return _resolve_type_alias( + origin, type_parameters=type_parameters, - module=module, - localns=localns, - recursion_guard=recursion_guard, # we might need to update it somehow? + module=origin.__module__ or module, + self_annotation=self_annotation, + recursion_guard=recursion_guard, ) - # generic alias aka parametrized type - case types.GenericAlias() as generic_alias: - match get_origin(generic_alias): - # check for an alias with parameters - case typing.TypeAliasType() as alias: # pyright: ignore[reportUnnecessaryComparison] - type_alias: AttributeAnnotation = AttributeAnnotation( - origin=TypeAliasType, - arguments=[], + case origin if issubclass(origin, Generic): + match origin.__class_getitem__( # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] + *( + type_parameters.get( + arg.__name__, + arg.__bound__ or Any, ) - resolved: AttributeAnnotation = resolve_attribute_annotation( - alias.__value__, - self_annotation=None, - type_parameters=type_parameters, - module=module, - localns=localns, - recursion_guard=recursion_guard, + if isinstance(arg, TypeVar) + else arg + for arg in get_args(annotation) + ) + ): + case GenericAlias(): + resolved_attribute = AttributeAnnotation( + module=getattr(origin, "__module__", ""), + origin=origin, ) - type_alias.origin = resolved.origin - type_alias.arguments = resolved.arguments - return type_alias - - # check if we can resolve it as generic - case parametrized if issubclass(parametrized, Generic): - parametrized_type: Any = parametrized.__class_getitem__( # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] - *( - type_parameters.get( - arg.__name__, - arg.__bound__ or Any, - ) - if isinstance(arg, TypeVar) - else arg - for arg in get_args(generic_alias) + if recursion_name := getattr( + origin, + "__qualname__", + getattr( + origin, + "__name__", + None, + ), + ): + if recursive := recursion_guard.get(recursion_name): + return recursive + + else: + recursion_guard[recursion_name] = resolved_attribute + + resolved_attribute.arguments = [ + resolve_attribute_annotation( + argument, + type_parameters=type_parameters, + module=module, + self_annotation=self_annotation, + recursion_guard=recursion_guard, ) - ) + for argument in get_args(annotation) + ] - match parametrized_type: - # verify if we got any specific type or generic alias again - case types.GenericAlias(): - return AttributeAnnotation( - origin=parametrized, - arguments=[ - resolve_attribute_annotation( - argument, - self_annotation=self_annotation, - type_parameters=type_parameters, - module=module, - localns=localns, - recursion_guard=recursion_guard, - ) - for argument in get_args(generic_alias) - ], - ) - - # use resolved type if it is not an alias again - case _: - return AttributeAnnotation( - origin=parametrized_type, - arguments=[], - ) - - # anything else - try to resolve a concrete type or use as is - case origin: - return AttributeAnnotation( - origin=origin, - arguments=[ - resolve_attribute_annotation( - argument, - self_annotation=self_annotation, - type_parameters=type_parameters, - module=module, - localns=localns, - recursion_guard=recursion_guard, - ) - for argument in get_args(generic_alias) - ], + return resolved_attribute + + # use resolved type if it is not an alias again + case resolved: # pyright: ignore + resolved_attribute = AttributeAnnotation( + module=getattr(resolved, "__module__", ""), # pyright: ignore + origin=resolved, ) + if recursive := recursion_guard.get(getattr(resolved, "__qualname__", "")): # pyright: ignore + return recursive + + else: + recursion_guard[getattr(resolved, "__qualname__", "")] = resolved_attribute # pyright: ignore + + resolved_attribute.arguments = [ + resolve_attribute_annotation( + argument, + type_parameters=type_parameters, + module=module, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + for argument in get_args(annotation) + ] + + return resolved_attribute - # type alias - case typing.TypeAliasType() as alias: - type_alias: AttributeAnnotation = AttributeAnnotation( - origin=TypeAliasType, - arguments=[], + case origin: + resolved_attribute = AttributeAnnotation( + module=getattr(origin, "__module__", ""), + origin=origin, ) - resolved: AttributeAnnotation = resolve_attribute_annotation( - alias.__value__, - self_annotation=None, + if recursive := recursion_guard.get(getattr(origin, "__qualname__", "")): + return recursive + + resolved_attribute.arguments = [ + resolve_attribute_annotation( + argument, + type_parameters=type_parameters, + module=module, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + for argument in get_args(annotation) + ] + + return resolved_attribute + + +def _resolve_special_generic_alias( + annotation: Any, + /, + module: str, + type_parameters: Mapping[str, Any], + self_annotation: AttributeAnnotation | None, + recursion_guard: MutableMapping[str, AttributeAnnotation], +) -> AttributeAnnotation: + origin: type[Any] = get_origin(annotation) + resolved_attribute = AttributeAnnotation( + module=getattr(origin, "__module__", ""), + origin=origin, + ) + if recursive := recursion_guard.get(getattr(origin, "__qualname__", "")): # pyright: ignore + return recursive + + else: + recursion_guard[getattr(origin, "__qualname__", "")] = resolved_attribute # pyright: ignore + + resolved_attribute.arguments = [ + resolve_attribute_annotation( + argument, + type_parameters=type_parameters, + module=module, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + for argument in get_args(annotation) + ] + + return resolved_attribute + + +def _resolve_type_alias( + annotation: TypeAliasType, + /, + module: str, + type_parameters: Mapping[str, Any], + self_annotation: AttributeAnnotation | None, + recursion_guard: MutableMapping[str, AttributeAnnotation], +) -> AttributeAnnotation: + resolved_attribute = AttributeAnnotation( + module=getattr(annotation, "__module__", module), + origin=MISSING, + ) + if recursive := recursion_guard.get(annotation.__name__): + return recursive + + else: + recursion_guard[annotation.__name__] = resolved_attribute + + resolved: AttributeAnnotation = resolve_attribute_annotation( + annotation.__value__, + module=getattr(annotation, "__module__", module), + type_parameters=type_parameters, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + + resolved_attribute.origin = resolved.origin + resolved_attribute.arguments = resolved.arguments + resolved_attribute.extra = resolved.extra + resolved_attribute.required = resolved.required + + return resolved_attribute + + +def _resolve_type_var( + annotation: TypeVar, + /, + module: str, + type_parameters: Mapping[str, Any], + self_annotation: AttributeAnnotation | None, + recursion_guard: MutableMapping[str, AttributeAnnotation], +) -> AttributeAnnotation: + return resolve_attribute_annotation( + type_parameters.get( + annotation.__name__, + # use bound as default or Any otherwise + annotation.__bound__ or Any, + ), + module=module, + type_parameters=type_parameters, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + + +def _resolve_type_union( + annotation: UnionType, + /, + module: str, + type_parameters: Mapping[str, Any], + self_annotation: AttributeAnnotation | None, + recursion_guard: MutableMapping[str, AttributeAnnotation], +) -> AttributeAnnotation: + arguments: Sequence[AttributeAnnotation] = [ + resolve_attribute_annotation( + argument, + type_parameters=type_parameters, + module=module, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + for argument in get_args(annotation) + ] + return AttributeAnnotation( + module=UnionType.__module__, + origin=UnionType, # pyright: ignore[reportArgumentType] + arguments=arguments, + required=all(argument.required for argument in arguments), + ) + + +def _resolve_callable( + annotation: Any, + /, + module: str, + type_parameters: Mapping[str, Any], + self_annotation: AttributeAnnotation | None, + recursion_guard: MutableMapping[str, AttributeAnnotation], +) -> AttributeAnnotation: + return AttributeAnnotation( + module=Callable.__module__, + origin=Callable, + arguments=[ + resolve_attribute_annotation( + argument, type_parameters=type_parameters, module=module, - localns=localns, + self_annotation=self_annotation, recursion_guard=recursion_guard, ) - type_alias.origin = resolved.origin - type_alias.arguments = resolved.arguments - return type_alias - - # type parameter - case typing.TypeVar(): - return resolve_attribute_annotation( - # try to resolve it from current parameters if able - type_parameters.get( - annotation.__name__, - # use bound as default or Any otherwise - annotation.__bound__ or Any, - ), - self_annotation=None, + for argument in get_args(annotation) + ], + ) + + +def _resolve_type_box( + annotation: Any, + /, + module: str, + type_parameters: Mapping[str, Any], + self_annotation: AttributeAnnotation | None, + recursion_guard: MutableMapping[str, AttributeAnnotation], +) -> AttributeAnnotation: + return resolve_attribute_annotation( + get_args(annotation)[0], + type_parameters=type_parameters, + module=module, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + + +def _resolve_type_not_required( + annotation: Any, + /, + module: str, + type_parameters: Mapping[str, Any], + self_annotation: AttributeAnnotation | None, + recursion_guard: MutableMapping[str, AttributeAnnotation], +) -> AttributeAnnotation: + return resolve_attribute_annotation( + get_args(annotation)[0], + type_parameters=type_parameters, + module=module, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ).update_required(False) + + +def _resolve_type_optional( + annotation: Any, + /, + module: str, + type_parameters: Mapping[str, Any], + self_annotation: AttributeAnnotation | None, + recursion_guard: MutableMapping[str, AttributeAnnotation], +) -> AttributeAnnotation: + return AttributeAnnotation( + module=UnionType.__module__, + origin=UnionType, # pyright: ignore[reportArgumentType] + arguments=[ + resolve_attribute_annotation( + get_args(annotation)[0], type_parameters=type_parameters, module=module, - localns=localns, + self_annotation=self_annotation, recursion_guard=recursion_guard, - ) + ), + AttributeAnnotation( + module=NoneType.__module__, + origin=NoneType, + ), + ], + ) - case typing.ParamSpec(): - sys.stderr.write( - "ParamSpec is not supported for attribute annotations," - " ignoring with Any type - it might incorrectly validate types\n" - ) - return AttributeAnnotation( - origin=Any, - arguments=[], - ) - case typing.TypeVarTuple(): - sys.stderr.write( - "TypeVarTuple is not supported for attribute annotations," - " ignoring with Any type - it might incorrectly validate types\n" - ) - return AttributeAnnotation( - origin=Any, - arguments=[], +def _resolve_type_typeddict( + annotation: Any, + /, + module: str, + type_parameters: Mapping[str, Any], + self_annotation: AttributeAnnotation | None, + recursion_guard: MutableMapping[str, AttributeAnnotation], +) -> AttributeAnnotation: + resolved_attribute = AttributeAnnotation( + module=annotation.__module__, + origin=annotation, + ) + if recursive := recursion_guard.get(annotation.__qualname__): + return recursive + + else: + recursion_guard[annotation.__qualname__] = resolved_attribute + + resolved_attribute.arguments = [ + resolve_attribute_annotation( + argument, + type_parameters=type_parameters, + module=module, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + for argument in get_args(annotation) + ] + + attributes: dict[str, AttributeAnnotation] = {} + for key, element in get_type_hints( + annotation, + localns={annotation.__name__: annotation}, + ).items(): + attributes[key] = resolve_attribute_annotation( + element, + type_parameters=type_parameters, + module=annotation.__module__, + self_annotation=resolved_attribute, + recursion_guard=recursion_guard, + ).update_required(key in annotation.__required_keys__) + resolved_attribute.extra = attributes + return resolved_attribute + + +def _resolve_type( + annotation: Any, + /, + module: str, + type_parameters: Mapping[str, Any], + self_annotation: AttributeAnnotation | None, + recursion_guard: MutableMapping[str, AttributeAnnotation], +) -> AttributeAnnotation: + if recursion_name := getattr( + annotation, + "__qualname__", + getattr( + annotation, + "__name__", + None, + ), + ): + if recursive := recursion_guard.get(recursion_name): + return recursive + + return AttributeAnnotation( + module=getattr(annotation, "__module__", ""), + origin=annotation, + arguments=[ + resolve_attribute_annotation( + argument, + type_parameters=type_parameters, + module=module, + self_annotation=self_annotation, + recursion_guard=recursion_guard, ) + for argument in get_args(annotation) + ], + ) - case _: - pass # proceed to resolving based on origin - # resolve based on origin if any +def resolve_attribute_annotation( # noqa: C901, PLR0911, PLR0912 + annotation: Any, + /, + module: str, + type_parameters: Mapping[str, Any], + self_annotation: AttributeAnnotation | None, + recursion_guard: MutableMapping[str, AttributeAnnotation], +) -> AttributeAnnotation: match get_origin(annotation) or annotation: - case types.UnionType | typing.Union: - return AttributeAnnotation( - origin=UnionType, # pyright: ignore[reportArgumentType] - arguments=[ - recursion_guard.get( - argument, - resolve_attribute_annotation( - argument, - self_annotation=self_annotation, - type_parameters=type_parameters, - module=module, - localns=localns, - recursion_guard=recursion_guard, - ), - ) - for argument in get_args(annotation) - ], + case types.NoneType | None: + return _resolve_none( + annotation=annotation, ) - case typing.Callable: # pyright: ignore[reportUnknownMemberType, reportAttributeAccessIssue] - return AttributeAnnotation( - origin=typing.Callable, - arguments=[ - resolve_attribute_annotation( - argument, - self_annotation=self_annotation, - type_parameters=type_parameters, - module=module, - localns=localns, - recursion_guard=recursion_guard, - ) - for argument in get_args(annotation) - ], + case haiway_types.Missing: + return _resolve_missing( + annotation=annotation, ) - case typing.Self: # pyright: ignore[reportUnknownMemberType] - if not self_annotation: - sys.stderr.write( - "Unresolved Self attribute annotation," - " ignoring with Any type - it might incorrectly validate types\n" - ) - return AttributeAnnotation( - origin=Any, - arguments=[], - ) + case types.UnionType | typing.Union: + return _resolve_type_union( + annotation, + module=module, + type_parameters=type_parameters, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) - return self_annotation + case typing.Literal: + return _resolve_literal(annotation) - # unwrap from irrelevant type wrappers - case typing.Annotated | typing.Final | typing.Required | typing.NotRequired: - return resolve_attribute_annotation( - get_args(annotation)[0], + case typeddict if is_typeddict(typeddict): + return _resolve_type_typeddict( + typeddict, + module=module, + type_parameters=type_parameters, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + + case typing.Callable: # pyright: ignore + return _resolve_callable( + annotation, + module=module, + type_parameters=type_parameters, self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + + case typing.Annotated | typing.Final | typing.Required: + return _resolve_type_box( + annotation, + module=module, type_parameters=type_parameters, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + + case typing.NotRequired: + return _resolve_type_not_required( + annotation, module=module, - localns=localns, + type_parameters=type_parameters, + self_annotation=self_annotation, recursion_guard=recursion_guard, ) case typing.Optional: # optional is a Union[Value, None] - return AttributeAnnotation( - origin=UnionType, # pyright: ignore[reportArgumentType] - arguments=[ - resolve_attribute_annotation( - get_args(annotation)[0], - self_annotation=self_annotation, + return _resolve_type_optional( + annotation, + module=module, + type_parameters=type_parameters, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + + case typing.Self: # pyright: ignore + if self_annotation: + return self_annotation + + else: + raise RuntimeError(f"Unresolved Self annotation: {annotation}") + + case _: + match annotation: + case str() | ForwardRef(): + return _resolve_forward_ref( + annotation, + module=module, type_parameters=type_parameters, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + + case GenericAlias(): + return _resolve_generic_alias( + annotation, module=module, - localns=localns, + type_parameters=type_parameters, + self_annotation=self_annotation, recursion_guard=recursion_guard, - ), - AttributeAnnotation( - origin=NoneType, - arguments=[], - ), - ], - ) + ) - case typing.Literal: - return AttributeAnnotation( - origin=Literal, - arguments=list(get_args(annotation)), - ) + case _GenericAlias(): + return _resolve_special_generic_alias( + annotation, + module=module, + type_parameters=type_parameters, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) - case other: # finally use whatever there was - return AttributeAnnotation( - origin=other, - arguments=[ - resolve_attribute_annotation( - argument, + case TypeAliasType(): + return _resolve_type_alias( + annotation, + module=module, + type_parameters=type_parameters, self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + + case TypeVar(): + return _resolve_type_var( + annotation, + module=module, type_parameters=type_parameters, + self_annotation=self_annotation, + recursion_guard=recursion_guard, + ) + + case ParamSpec(): + raise NotImplementedError(f"Unresolved ParamSpec annotation: {annotation}") + + case TypeVarTuple(): + raise NotImplementedError(f"Unresolved TypeVarTuple annotation: {annotation}") + + case _: # finally use whatever there was + return _resolve_type( + annotation, module=module, - localns=localns, + type_parameters=type_parameters, + self_annotation=self_annotation, recursion_guard=recursion_guard, ) - for argument in get_args(other) - ], - ) diff --git a/src/haiway/state/structure.py b/src/haiway/state/structure.py index 4c5ef20..b0f39a1 100644 --- a/src/haiway/state/structure.py +++ b/src/haiway/state/structure.py @@ -1,4 +1,3 @@ -from collections.abc import Callable from copy import deepcopy from types import EllipsisType, GenericAlias from typing import ( @@ -16,7 +15,11 @@ from haiway.state.attributes import AttributeAnnotation, attribute_annotations from haiway.state.path import AttributePath -from haiway.state.validation import attribute_validator +from haiway.state.validation import ( + AttributeValidation, + AttributeValidationContext, + AttributeValidator, +) from haiway.types import MISSING, Missing, not_missing __all__ = [ @@ -28,20 +31,27 @@ class StateAttribute[Value]: def __init__( self, + name: str, annotation: AttributeAnnotation, default: Value | Missing, - validator: Callable[[Any], Value], + validator: AttributeValidation[Value], ) -> None: + self.name: str = name self.annotation: AttributeAnnotation = annotation self.default: Value | Missing = default - self.validator: Callable[[Any], Value] = validator + self.validator: AttributeValidation[Value] = validator def validated( self, value: Any | Missing, /, + context: AttributeValidationContext, ) -> Value: - return self.validator(self.default if value is MISSING else value) + with context.scope(f".{self.name}"): + return self.validator( + self.default if value is MISSING else value, + context=context, + ) @dataclass_transform( @@ -78,10 +88,13 @@ def __new__( if ((get_origin(annotation) or annotation) is ClassVar) or key.startswith("__"): continue + default: Any = getattr(state_type, key, MISSING) attributes[key] = StateAttribute( - annotation=annotation, - default=getattr(state_type, key, MISSING), - validator=attribute_validator(annotation), + name=key, + annotation=annotation.update_required(default is MISSING), + default=default, + # TODO: add self to recursion guard + validator=AttributeValidator.of(annotation, recursion_guard={}), ) state_type.__ATTRIBUTES__ = attributes # pyright: ignore[reportAttributeAccessIssue] @@ -170,16 +183,18 @@ def __init__( **kwargs: Any, ) -> None: for name, attribute in self.__ATTRIBUTES__.items(): - object.__setattr__( - self, # pyright: ignore[reportUnknownArgumentType] - name, - attribute.validated( - kwargs.get( - name, - MISSING, + with AttributeValidationContext().scope(self.__class__.__qualname__) as context: + object.__setattr__( + self, # pyright: ignore[reportUnknownArgumentType] + name, + attribute.validated( + kwargs.get( + name, + MISSING, + ), + context=context, ), - ), - ) + ) def updating[Value]( self, diff --git a/src/haiway/state/validation.py b/src/haiway/state/validation.py index f757920..01bc270 100644 --- a/src/haiway/state/validation.py +++ b/src/haiway/state/validation.py @@ -1,10 +1,11 @@ -from collections.abc import Callable, Mapping, Sequence, Set +from collections import deque +from collections.abc import Callable, Mapping, MutableMapping, Sequence, Set from datetime import date, datetime, time, timedelta, timezone from enum import Enum from pathlib import Path from re import Pattern -from types import MappingProxyType, NoneType, UnionType -from typing import Any, Literal, Protocol, Union +from types import EllipsisType, MappingProxyType, NoneType, TracebackType, UnionType +from typing import Any, Literal, Protocol, Self, Union, is_typeddict from typing import Mapping as MappingType # noqa: UP035 from typing import Sequence as SequenceType # noqa: UP035 from typing import Sequence as SetType # noqa: UP035 @@ -14,36 +15,145 @@ from haiway.types import MISSING, Missing __all__ = [ - "attribute_validator", + "AttributeValidation", + "AttributeValidationContext", + "AttributeValidationError", + "AttributeValidator", ] -def attribute_validator( - annotation: AttributeAnnotation, - /, -) -> Callable[[Any], Any]: - if validator := VALIDATORS.get(annotation.origin): - return validator(annotation) +class AttributeValidationContext: + def __init__(self) -> None: + self._path: deque[str] = deque() + + def __str__(self) -> str: + return "".join(self._path) + + def scope( + self, + path: str, + /, + ) -> "AttributeValidationContextScope": + return AttributeValidationContextScope(self, component=path) + + +class AttributeValidationContextScope: + def __init__( + self, + context: AttributeValidationContext, + /, + component: str, + ) -> None: + self.context: AttributeValidationContext = context + self.component: str = component + + def __enter__(self) -> AttributeValidationContext: + self.context._path.append(self.component) # pyright: ignore[reportPrivateUsage] + return self.context + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + try: + if exc_val is not None and exc_type is not AttributeValidationError: + raise AttributeValidationError( + f"Validation error at {self.context!s}", + ) from exc_val + + finally: + self.context._path.pop() # pyright: ignore # FIXME + + +class AttributeValidation[Type](Protocol): + def __call__( + self, + value: Any, + /, + *, + context: AttributeValidationContext, + ) -> Type: ... + + +class AttributeValidationError(Exception): + pass + + +class AttributeValidator[Type]: + @classmethod + def of( + cls, + annotation: AttributeAnnotation, + /, + *, + recursion_guard: MutableMapping[str, AttributeValidation[Any]], + ) -> AttributeValidation[Any]: + if isinstance(annotation.origin, NotImplementedError | RuntimeError): + raise annotation.origin # raise an error if origin was not properly resolved + + if recursive := recursion_guard.get(str(annotation)): + return recursive + + validator: Self = cls( + annotation, + validation=MISSING, + ) + recursion_guard[str(annotation)] = validator - elif hasattr(annotation.origin, "__IMMUTABLE__"): - return _prepare_validator_of_type(annotation) + if common := VALIDATORS.get(annotation.origin): + validator.validation = common(annotation, recursion_guard) - elif issubclass(annotation.origin, Protocol): - return _prepare_validator_of_type(annotation) + elif hasattr(annotation.origin, "__IMMUTABLE__"): + validator.validation = _prepare_validator_of_type(annotation, recursion_guard) - elif issubclass(annotation.origin, Enum): - return _prepare_validator_of_type(annotation) + elif is_typeddict(annotation.origin): + validator.validation = _prepare_validator_of_typed_dict(annotation, recursion_guard) - else: - raise TypeError(f"Unsupported type annotation: {annotation}") + elif issubclass(annotation.origin, Protocol): + validator.validation = _prepare_validator_of_type(annotation, recursion_guard) + + elif issubclass(annotation.origin, Enum): + validator.validation = _prepare_validator_of_type(annotation, recursion_guard) + + else: + raise TypeError(f"Unsupported type annotation: {annotation}") + + return validator + + def __init__( + self, + annotation: AttributeAnnotation, + validation: AttributeValidation[Type] | Missing, + ) -> None: + self.annotation: AttributeAnnotation = annotation + self.validation: AttributeValidation[Type] | Missing = validation + + def __call__( + self, + value: Any, + /, + *, + context: AttributeValidationContext, + ) -> Any: + assert self.validation is not MISSING # nosec: B101 + return self.validation( # pyright: ignore[reportCallIssue, reportUnknownVariableType] + value, + context=context, + ) def _prepare_validator_of_any( annotation: AttributeAnnotation, /, -) -> Callable[[Any], Any]: + recursion_guard: MutableMapping[str, AttributeValidation[Any]], +) -> AttributeValidation[Any]: def validator( value: Any, + /, + *, + context: AttributeValidationContext, ) -> Any: return value # any is always valid @@ -53,9 +163,13 @@ def validator( def _prepare_validator_of_none( annotation: AttributeAnnotation, /, -) -> Callable[[Any], Any]: + recursion_guard: MutableMapping[str, AttributeValidation[Any]], +) -> AttributeValidation[Any]: def validator( value: Any, + /, + *, + context: AttributeValidationContext, ) -> Any: if value is None: return value @@ -69,9 +183,13 @@ def validator( def _prepare_validator_of_missing( annotation: AttributeAnnotation, /, -) -> Callable[[Any], Any]: + recursion_guard: MutableMapping[str, AttributeValidation[Any]], +) -> AttributeValidation[Any]: def validator( value: Any, + /, + *, + context: AttributeValidationContext, ) -> Any: if value is MISSING: return value @@ -85,12 +203,16 @@ def validator( def _prepare_validator_of_literal( annotation: AttributeAnnotation, /, -) -> Callable[[Any], Any]: - elements: list[Any] = annotation.arguments + recursion_guard: MutableMapping[str, AttributeValidation[Any]], +) -> AttributeValidation[Any]: + elements: Sequence[Any] = annotation.arguments formatted_type: str = str(annotation) def validator( value: Any, + /, + *, + context: AttributeValidationContext, ) -> Any: if value in elements: return value @@ -104,12 +226,16 @@ def validator( def _prepare_validator_of_type( annotation: AttributeAnnotation, /, -) -> Callable[[Any], Any]: + recursion_guard: MutableMapping[str, AttributeValidation[Any]], +) -> AttributeValidation[Any]: validated_type: type[Any] = annotation.origin formatted_type: str = str(annotation) - def type_validator( + def validator( value: Any, + /, + *, + context: AttributeValidationContext, ) -> Any: match value: case value if isinstance(value, validated_type): @@ -118,21 +244,28 @@ def type_validator( case _: raise TypeError(f"'{value}' is not matching expected type of '{formatted_type}'") - return type_validator + return validator def _prepare_validator_of_set( annotation: AttributeAnnotation, /, -) -> Callable[[Any], Any]: - element_validator: Callable[[Any], Any] = attribute_validator(annotation.arguments[0]) + recursion_guard: MutableMapping[str, AttributeValidation[Any]], +) -> AttributeValidation[Any]: + element_validator: AttributeValidation[Any] = AttributeValidator.of( + annotation.arguments[0], + recursion_guard=recursion_guard, + ) formatted_type: str = str(annotation) def validator( value: Any, + /, + *, + context: AttributeValidationContext, ) -> Any: if isinstance(value, set): - return frozenset(element_validator(element) for element in value) # pyright: ignore[reportUnknownVariableType] + return frozenset(element_validator(element, context=context) for element in value) # pyright: ignore[reportUnknownVariableType] else: raise TypeError(f"'{value}' is not matching expected type of '{formatted_type}'") @@ -143,16 +276,24 @@ def validator( def _prepare_validator_of_sequence( annotation: AttributeAnnotation, /, -) -> Callable[[Any], Any]: - element_validator: Callable[[Any], Any] = attribute_validator(annotation.arguments[0]) + recursion_guard: MutableMapping[str, AttributeValidation[Any]], +) -> AttributeValidation[Any]: + element_validator: AttributeValidation[Any] = AttributeValidator.of( + annotation.arguments[0], + recursion_guard=recursion_guard, + ) formatted_type: str = str(annotation) def validator( value: Any, + /, + *, + context: AttributeValidationContext, ) -> Any: match value: case [*elements]: - return tuple(element_validator(element) for element in elements) + # TODO: context indices + return tuple(element_validator(element, context=context) for element in elements) case _: raise TypeError(f"'{value}' is not matching expected type of '{formatted_type}'") @@ -163,18 +304,32 @@ def validator( def _prepare_validator_of_mapping( annotation: AttributeAnnotation, /, -) -> Callable[[Any], Any]: - key_validator: Callable[[Any], Any] = attribute_validator(annotation.arguments[0]) - value_validator: Callable[[Any], Any] = attribute_validator(annotation.arguments[1]) + recursion_guard: MutableMapping[str, AttributeValidation[Any]], +) -> AttributeValidation[Any]: + key_validator: AttributeValidation[Any] = AttributeValidator.of( + annotation.arguments[0], + recursion_guard=recursion_guard, + ) + value_validator: AttributeValidation[Any] = AttributeValidator.of( + annotation.arguments[1], + recursion_guard=recursion_guard, + ) formatted_type: str = str(annotation) def validator( value: Any, + /, + *, + context: AttributeValidationContext, ) -> Any: match value: case {**elements}: + # TODO: context keys return MappingProxyType( - {key_validator(key): value_validator(value) for key, value in elements.items()} + { + key_validator(key, context=context): value_validator(value, context=context) + for key, value in elements.items() + } ) case _: @@ -186,17 +341,30 @@ def validator( def _prepare_validator_of_tuple( annotation: AttributeAnnotation, /, -) -> Callable[[Any], Any]: - if annotation.arguments[-1].origin == Ellipsis: - element_validator: Callable[[Any], Any] = attribute_validator(annotation.arguments[0]) + recursion_guard: MutableMapping[str, AttributeValidation[Any]], +) -> AttributeValidation[Any]: + if ( + annotation.arguments[-1].origin == Ellipsis + or annotation.arguments[-1].origin == EllipsisType + ): + element_validator: AttributeValidation[Any] = AttributeValidator.of( + annotation.arguments[0], + recursion_guard=recursion_guard, + ) formatted_type: str = str(annotation) def validator( value: Any, + /, + *, + context: AttributeValidationContext, ) -> Any: match value: case [*elements]: - return tuple(element_validator(element) for element in elements) + # TODO: context indices + return tuple( + element_validator(element, context=context) for element in elements + ) case _: raise TypeError( @@ -206,14 +374,18 @@ def validator( return validator else: - element_validators: list[Callable[[Any], Any]] = [ - attribute_validator(alternative) for alternative in annotation.arguments + element_validators: list[AttributeValidation[Any]] = [ + AttributeValidator.of(alternative, recursion_guard=recursion_guard) + for alternative in annotation.arguments ] elements_count: int = len(element_validators) formatted_type: str = str(annotation) def validator( value: Any, + /, + *, + context: AttributeValidationContext, ) -> Any: match value: case [*elements]: @@ -222,8 +394,10 @@ def validator( f"'{value}' is not matching expected type of '{formatted_type}'" ) + # TODO: context indices return tuple( - element_validators[idx](value) for idx, value in enumerate(elements) + element_validators[idx](value, context=context) + for idx, value in enumerate(elements) ) case _: @@ -237,19 +411,24 @@ def validator( def _prepare_validator_of_union( annotation: AttributeAnnotation, /, -) -> Callable[[Any], Any]: - validators: list[Callable[[Any], Any]] = [ - attribute_validator(alternative) for alternative in annotation.arguments + recursion_guard: MutableMapping[str, AttributeValidation[Any]], +) -> AttributeValidation[Any]: + validators: list[AttributeValidation[Any]] = [ + AttributeValidator.of(alternative, recursion_guard=recursion_guard) + for alternative in annotation.arguments ] formatted_type: str = str(annotation) def validator( value: Any, + /, + *, + context: AttributeValidationContext, ) -> Any: errors: list[Exception] = [] for validator in validators: try: - return validator(value) + return validator(value, context=context) except Exception as exc: errors.append(exc) @@ -265,13 +444,18 @@ def validator( def _prepare_validator_of_callable( annotation: AttributeAnnotation, /, -) -> Callable[[Any], Any]: + recursion_guard: MutableMapping[str, AttributeValidation[Any]], +) -> AttributeValidation[Any]: formatted_type: str = str(annotation) def validator( value: Any, + /, + *, + context: AttributeValidationContext, ) -> Any: if callable(value): + # TODO: we could verify callable signature here return value else: @@ -280,7 +464,65 @@ def validator( return validator -VALIDATORS: Mapping[Any, Callable[[AttributeAnnotation], Callable[[Any], Any]]] = { +def _prepare_validator_of_typed_dict( + annotation: AttributeAnnotation, + /, + recursion_guard: MutableMapping[str, AttributeValidation[Any]], +) -> AttributeValidation[Any]: + def key_validator( + value: Any, + /, + *, + context: AttributeValidationContext, + ) -> Any: + match value: + case value if isinstance(value, str): + return value + + case _: + raise TypeError(f"'{value}' is not matching expected type of 'str'") + + values_validators: dict[str, AttributeValidation[Any]] = { + key: AttributeValidator.of(element, recursion_guard=recursion_guard) + for key, element in annotation.extra.items() + } + formatted_type: str = str(annotation) + + def validator( + value: Any, + /, + *, + context: AttributeValidationContext, + ) -> Any: + match value: + case {**elements}: + validated: dict[str, Any] = {} + for key, validator in values_validators.items(): + element: Any = elements.get(key, MISSING) + if element is not MISSING: + # TODO: context keys + validated[key_validator(key, context=context)] = validator( + element, context=context + ) + + return MappingProxyType(validated) + + case _: + raise TypeError(f"'{value}' is not matching expected type of '{formatted_type}'") + + return validator + + +VALIDATORS: Mapping[ + Any, + Callable[ + [ + AttributeAnnotation, + MutableMapping[str, AttributeValidation[Any]], + ], + AttributeValidation[Any], + ], +] = { Any: _prepare_validator_of_any, NoneType: _prepare_validator_of_none, Missing: _prepare_validator_of_missing, diff --git a/tests/test_state.py b/tests/test_state.py index e09e098..2c0e9f8 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -1,7 +1,7 @@ from collections.abc import Callable, Sequence, Set from datetime import date, datetime from enum import StrEnum -from typing import Literal, Protocol, Self, runtime_checkable +from typing import Literal, NotRequired, Protocol, Required, Self, TypedDict, runtime_checkable from uuid import UUID, uuid4 from haiway import MISSING, Missing, State, frozenlist @@ -10,7 +10,16 @@ def test_basic_initializes_with_arguments() -> None: class Selection(StrEnum): A = "A" - B = "A" + B = "B" + + class Nes(State): + val: str + + class TypedValues(TypedDict): + val: str + mis: int | Missing + req: Required[Nes] + nreq: NotRequired[bool] @runtime_checkable class Proto(Protocol): @@ -32,6 +41,7 @@ class Basics(State): function: Callable[[], None] proto: Proto selection: Selection + typeddict: TypedValues basic = Basics( uuid=uuid4(), @@ -49,6 +59,11 @@ class Basics(State): function=lambda: None, proto=lambda: None, selection=Selection.A, + typeddict={ + "val": "ok", + "mis": 42, + "req": Nes(val="ok"), + }, ) assert basic.string == "string" assert basic.literal == "A" @@ -60,6 +75,11 @@ class Basics(State): assert basic.optional == "optional" assert basic.none is None assert basic.selection == Selection.A + assert basic.typeddict == TypedValues( + val="ok", + mis=42, + req=Nes(val="ok"), + ) assert callable(basic.function) assert isinstance(basic.proto, Proto)