Skip to content

Commit

Permalink
Allow disposables within scope contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
KaQuMiQ authored Oct 28, 2024
1 parent a2814c5 commit b55601b
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 27 deletions.
4 changes: 3 additions & 1 deletion src/haiway/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from haiway.context import (
Disposable,
Disposables,
MissingContext,
MissingState,
ScopeMetrics,
ctx,
)
from haiway.helpers import Disposable, Disposables, asynchronous, cache, retry, throttle, timeout
from haiway.helpers import asynchronous, cache, retry, throttle, timeout
from haiway.state import State
from haiway.types import (
MISSING,
Expand Down
3 changes: 3 additions & 0 deletions src/haiway/context/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
from haiway.context.access import ctx
from haiway.context.disposables import Disposable, Disposables
from haiway.context.metrics import ScopeMetrics
from haiway.context.types import MissingContext, MissingState

__all__ = [
"ctx",
"Disposable",
"Disposables",
"MissingContext",
"MissingState",
"ScopeMetrics",
Expand Down
102 changes: 79 additions & 23 deletions src/haiway/context/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from collections.abc import (
Callable,
Coroutine,
Iterable,
)
from logging import Logger
from types import TracebackType
from typing import Any, final

from haiway.context.disposables import Disposable, Disposables
from haiway.context.metrics import MetricsContext, ScopeMetrics
from haiway.context.state import StateContext
from haiway.context.tasks import TaskGroupContext
Expand All @@ -23,75 +25,113 @@

@final
class ScopeContext:
def __init__(
def __init__( # noqa: PLR0913
self,
trace_id: str | None,
name: str,
logger: Logger | None,
state: tuple[State, ...],
disposables: Disposables | None,
task_group: TaskGroupContext,
state: StateContext,
metrics: MetricsContext,
completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None,
) -> None:
self._task_group: TaskGroupContext = task_group
self._state: StateContext = state
self._metrics: MetricsContext = metrics
self._logger: Logger | None = logger
self._trace_id: str | None = trace_id
self._name: str = name
self._state_context: StateContext
self._state: tuple[State, ...] = state
self._disposables: Disposables | None = disposables
self._metrics_context: MetricsContext
self._completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None = completion

freeze(self)

def __enter__(self) -> None:
assert self._completion is None, "Can't enter synchronous context with completion" # nosec: B101
assert self._disposables is None, "Can't enter synchronous context with disposables" # nosec: B101

self._state_context = StateContext.updated(self._state)
self._metrics_context = MetricsContext.scope(
self._name,
logger=self._logger,
trace_id=self._trace_id,
)

self._state.__enter__()
self._metrics.__enter__()
self._state_context.__enter__()
self._metrics_context.__enter__()

def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self._metrics.__exit__(
self._metrics_context.__exit__(
exc_type=exc_type,
exc_val=exc_val,
exc_tb=exc_tb,
)

self._state.__exit__(
self._state_context.__exit__(
exc_type=exc_type,
exc_val=exc_val,
exc_tb=exc_tb,
)

async def __aenter__(self) -> None:
self._state.__enter__()
self._metrics.__enter__()
await self._task_group.__aenter__()

if self._disposables:
self._state_context = StateContext.updated(
(*self._state, *await self._disposables.__aenter__())
)

else:
self._state_context = StateContext.updated(self._state)

self._metrics_context = MetricsContext.scope(
self._name,
logger=self._logger,
trace_id=self._trace_id,
)

self._state_context.__enter__()
self._metrics_context.__enter__()

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
if self._disposables:
await self._disposables.__aexit__(
exc_type=exc_type,
exc_val=exc_val,
exc_tb=exc_tb,
)

await self._task_group.__aexit__(
exc_type=exc_type,
exc_val=exc_val,
exc_tb=exc_tb,
)

self._metrics.__exit__(
self._metrics_context.__exit__(
exc_type=exc_type,
exc_val=exc_val,
exc_tb=exc_tb,
)

self._state.__exit__(
self._state_context.__exit__(
exc_type=exc_type,
exc_val=exc_val,
exc_tb=exc_tb,
)

if completion := self._completion:
await completion(self._metrics._metrics) # pyright: ignore[reportPrivateUsage]
await completion(self._metrics_context._metrics) # pyright: ignore[reportPrivateUsage]


@final
Expand All @@ -101,6 +141,7 @@ def scope(
name: str,
/,
*state: State,
disposables: Disposables | Iterable[Disposable] | None = None,
logger: Logger | None = None,
trace_id: str | None = None,
completion: Callable[[ScopeMetrics], Coroutine[None, None, None]] | None = None,
Expand All @@ -114,9 +155,14 @@ def scope(
name: Value
name of the scope context
*state: State
state propagated within the scope context, will be merged with current if any\
by replacing current with provided on conflict
*state: State | Disposable
state propagated within the scope context, will be merged with current state by\
replacing current with provided on conflict.
disposables: Disposables | list[Disposable] | None
disposables consumed within the context when entered. Produced state will automatically\
be added to the scope state. Using asynchronous context is required if any disposables\
were provided.
logger: Logger | None
logger used within the scope context, when not provided current logger will be used\
Expand All @@ -138,14 +184,24 @@ def scope(
context object intended to enter context manager with it
"""

resolved_disposables: Disposables | None
match disposables:
case None:
resolved_disposables = None

case Disposables() as disposables:
resolved_disposables = disposables

case iterable:
resolved_disposables = Disposables(*iterable)

return ScopeContext(
trace_id=trace_id,
name=name,
logger=logger,
state=state,
disposables=resolved_disposables,
task_group=TaskGroupContext(),
metrics=MetricsContext.scope(
name,
logger=logger,
trace_id=trace_id,
),
state=StateContext.updated(state),
completion=completion,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ async def dispose(self) -> None:
elif exceptions:
raise exceptions[0]

def __bool__(self) -> bool:
return len(self._disposables) > 0

async def __aenter__(self) -> Iterable[State]:
return await self.initialize()

Expand Down
3 changes: 0 additions & 3 deletions src/haiway/helpers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from haiway.helpers.asynchronous import asynchronous
from haiway.helpers.cached import cache
from haiway.helpers.disposables import Disposable, Disposables
from haiway.helpers.retries import retry
from haiway.helpers.throttling import throttle
from haiway.helpers.timeouted import timeout

__all__ = [
"asynchronous",
"cache",
"Disposable",
"Disposables",
"retry",
"throttle",
"timeout",
Expand Down

0 comments on commit b55601b

Please sign in to comment.