Skip to content

Commit

Permalink
Allow attribute default factory
Browse files Browse the repository at this point in the history
  • Loading branch information
KaQuMiQ committed Feb 10, 2025
1 parent 09247d4 commit cff12f8
Show file tree
Hide file tree
Showing 12 changed files with 279 additions and 102 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "haiway"
description = "Framework for dependency injection and state management within structured concurrency model."
version = "0.9.1"
version = "0.10.0"
readme = "README.md"
maintainers = [
{ name = "Kacper Kaliński", email = "[email protected]" },
Expand Down
6 changes: 6 additions & 0 deletions src/haiway/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from haiway.state import AttributePath, AttributeRequirement, State
from haiway.types import (
MISSING,
Default,
DefaultValue,
Missing,
frozenlist,
is_missing,
Expand All @@ -39,6 +41,7 @@
always,
as_dict,
as_list,
as_set,
as_tuple,
async_always,
async_noop,
Expand All @@ -59,6 +62,8 @@
"AsyncQueue",
"AttributePath",
"AttributeRequirement",
"Default",
"DefaultValue",
"Disposable",
"Disposables",
"MetricsContext",
Expand All @@ -78,6 +83,7 @@
"always",
"as_dict",
"as_list",
"as_set",
"as_tuple",
"async_always",
"async_noop",
Expand Down
55 changes: 44 additions & 11 deletions src/haiway/state/structure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import typing
from collections.abc import Mapping
from collections.abc import Callable, Mapping
from types import EllipsisType, GenericAlias
from typing import (
Any,
Expand All @@ -10,48 +10,69 @@
cast,
dataclass_transform,
final,
overload,
)
from weakref import WeakValueDictionary

from haiway.state.attributes import AttributeAnnotation, attribute_annotations
from haiway.state.path import AttributePath
from haiway.state.validation import (
AttributeValidation,
AttributeValidator,
)
from haiway.types import MISSING, Missing, not_missing
from haiway.state.validation import AttributeValidation, AttributeValidator
from haiway.types import MISSING, DefaultValue, Missing, not_missing

__all__ = [
"State",
]


@overload
def Default[Value](
value: Value,
/,
) -> Value: ...


@overload
def Default[Value](
*,
factory: Callable[[], Value],
) -> Value: ...


def Default[Value](
value: Value | Missing = MISSING,
/,
*,
factory: Callable[[], Value] | Missing = MISSING,
) -> Value: # it is actually a DefaultValue, but type checker has to be fooled
return cast(Value, DefaultValue(value, factory=factory))


@final
class StateAttribute[Value]:
def __init__(
self,
name: str,
annotation: AttributeAnnotation,
default: Value | Missing,
default: DefaultValue[Value],
validator: AttributeValidation[Value],
) -> None:
self.name: str = name
self.annotation: AttributeAnnotation = annotation
self.default: Value | Missing = default
self.default: DefaultValue[Value] = default
self.validator: AttributeValidation[Value] = validator

def validated(
self,
value: Any | Missing,
/,
) -> Value:
return self.validator(self.default if value is MISSING else value)
return self.validator(self.default() if value is MISSING else value)


@dataclass_transform(
kw_only_default=True,
frozen_default=True,
field_specifiers=(),
field_specifiers=(DefaultValue,),
)
class StateMeta(type):
def __new__(
Expand Down Expand Up @@ -81,7 +102,7 @@ def __new__(
attributes[key] = StateAttribute(
name=key,
annotation=annotation.update_required(default is MISSING),
default=default,
default=_resolve_default(default),
validator=AttributeValidator.of(
annotation,
recursion_guard={
Expand Down Expand Up @@ -187,6 +208,18 @@ def __subclasscheck__( # noqa: C901, PLR0911, PLR0912
return False # we have different base / comparing to not parametrized


def _resolve_default[Value](
value: DefaultValue[Value] | Value | Missing,
) -> DefaultValue[Value]:
if isinstance(value, DefaultValue):
return cast(DefaultValue[Value], value)

return DefaultValue[Value](
value,
factory=MISSING,
)


_types_cache: WeakValueDictionary[
tuple[
Any,
Expand Down
3 changes: 3 additions & 0 deletions src/haiway/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from haiway.types.default import Default, DefaultValue
from haiway.types.frozen import frozenlist
from haiway.types.missing import MISSING, Missing, is_missing, not_missing, when_missing

__all__ = [
"MISSING",
"Default",
"DefaultValue",
"Missing",
"frozenlist",
"is_missing",
Expand Down
108 changes: 108 additions & 0 deletions src/haiway/types/default.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from collections.abc import Callable
from typing import Any, cast, final, overload

from haiway.types.missing import MISSING, Missing, not_missing
from haiway.utils.always import always

__all__ = [
"Default",
"DefaultValue",
]


@final
class DefaultValue[Value]:
@overload
def __init__(
self,
value: Value,
/,
) -> None: ...

@overload
def __init__(
self,
/,
*,
factory: Callable[[], Value],
) -> None: ...

@overload
def __init__(
self,
value: Value | Missing,
/,
*,
factory: Callable[[], Value] | Missing,
) -> None: ...

def __init__(
self,
value: Value | Missing = MISSING,
/,
*,
factory: Callable[[], Value] | Missing = MISSING,
) -> None:
assert ( # nosec: B101
value is MISSING or factory is MISSING
), "Can't specify both default value and factory"

self._value: Callable[[], Value | Missing]
if not_missing(factory):
object.__setattr__(
self,
"_value",
factory,
)

else:
object.__setattr__(
self,
"_value",
always(value),
)

def __call__(self) -> Value | Missing:
return self._value()

def __setattr__(
self,
__name: str,
__value: Any,
) -> None:
raise AttributeError("Missing can't be modified")

def __delattr__(
self,
__name: str,
) -> None:
raise AttributeError("Missing can't be modified")


@overload
def Default[Value](
value: Value,
/,
) -> Value: ...


@overload
def Default[Value](
*,
factory: Callable[[], Value],
) -> Value: ...


def Default[Value](
value: Value | Missing = MISSING,
/,
*,
factory: Callable[[], Value] | Missing = MISSING,
) -> Value: # it is actually a DefaultValue, but type checker has to be fooled most some cases
return cast(
Value,
DefaultValue(
value,
factory=factory,
),
)
1 change: 1 addition & 0 deletions src/haiway/types/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ def not_missing[Value](
def when_missing[Value](
check: Value | Missing,
/,
*,
value: Value,
) -> Value:
if check is MISSING:
Expand Down
6 changes: 3 additions & 3 deletions src/haiway/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from haiway.utils.always import always, async_always
from haiway.utils.collections import as_dict, as_list, as_set, as_tuple
from haiway.utils.env import getenv_bool, getenv_float, getenv_int, getenv_str, load_env
from haiway.utils.immutable import freeze
from haiway.utils.freezing import freeze
from haiway.utils.logs import setup_logging
from haiway.utils.mappings import as_dict
from haiway.utils.mimic import mimic_function
from haiway.utils.noop import async_noop, noop
from haiway.utils.queue import AsyncQueue
from haiway.utils.sequences import as_list, as_tuple

__all__ = [
"AsyncQueue",
"always",
"as_dict",
"as_list",
"as_set",
"as_tuple",
"async_always",
"async_noop",
Expand Down
Loading

0 comments on commit cff12f8

Please sign in to comment.