Skip to content

Commit

Permalink
Reduce number of required type params by using Sequence instead of …
Browse files Browse the repository at this point in the history
…`Collection`
  • Loading branch information
4c0n committed Feb 5, 2025
1 parent af19831 commit 50d58f7
Show file tree
Hide file tree
Showing 13 changed files with 144 additions and 166 deletions.
46 changes: 21 additions & 25 deletions meldingen_core/actions/attachment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Collection
from collections.abc import Sequence
from enum import StrEnum
from typing import AsyncIterator, Generic, TypeVar

Expand All @@ -14,16 +14,14 @@
from meldingen_core.validators import BaseMediaTypeIntegrityValidator, BaseMediaTypeValidator

A = TypeVar("A", bound=Attachment)
A_co = TypeVar("A_co", bound=Attachment, covariant=True)
M = TypeVar("M", bound=Melding)
M_co = TypeVar("M_co", bound=Melding, covariant=True)


class UploadAttachmentAction(Generic[A, A_co, M, M_co]):
class UploadAttachmentAction(Generic[A, M]):
_create_attachment: BaseAttachmentFactory[A, M]
_attachment_repository: BaseAttachmentRepository[A, A_co]
_attachment_repository: BaseAttachmentRepository[A]
_filesystem: Filesystem
_verify_token: TokenVerifier[M, M_co]
_verify_token: TokenVerifier[M]
_base_directory: str
_validate_media_type: BaseMediaTypeValidator
_validate_media_type_integrity: BaseMediaTypeIntegrityValidator
Expand All @@ -32,8 +30,8 @@ class UploadAttachmentAction(Generic[A, A_co, M, M_co]):
def __init__(
self,
attachment_factory: BaseAttachmentFactory[A, M],
attachment_repository: BaseAttachmentRepository[A, A_co],
token_verifier: TokenVerifier[M, M_co],
attachment_repository: BaseAttachmentRepository[A],
token_verifier: TokenVerifier[M],
media_type_validator: BaseMediaTypeValidator,
media_type_integrity_validator: BaseMediaTypeIntegrityValidator,
ingestor: BaseIngestor[A],
Expand Down Expand Up @@ -74,15 +72,15 @@ class AttachmentTypes(StrEnum):
THUMBNAIL = "thumbnail"


class DownloadAttachmentAction(Generic[A, A_co, M, M_co]):
_verify_token: TokenVerifier[M, M_co]
_attachment_repository: BaseAttachmentRepository[A, A_co]
class DownloadAttachmentAction(Generic[A, M]):
_verify_token: TokenVerifier[M]
_attachment_repository: BaseAttachmentRepository[A]
_filesystem: Filesystem

def __init__(
self,
token_verifier: TokenVerifier[M, M_co],
attachment_repository: BaseAttachmentRepository[A, A_co],
token_verifier: TokenVerifier[M],
attachment_repository: BaseAttachmentRepository[A],
filesystem: Filesystem,
):
self._verify_token = token_verifier
Expand Down Expand Up @@ -118,31 +116,29 @@ async def __call__(
raise NotFoundException("File not found") from exception


class ListAttachmentsAction(Generic[A, A_co, M, M_co]):
_verify_token: TokenVerifier[M, M_co]
_attachment_repository: BaseAttachmentRepository[A, A_co]
class ListAttachmentsAction(Generic[A, M]):
_verify_token: TokenVerifier[M]
_attachment_repository: BaseAttachmentRepository[A]

def __init__(
self, token_verifier: TokenVerifier[M, M_co], attachment_repository: BaseAttachmentRepository[A, A_co]
):
def __init__(self, token_verifier: TokenVerifier[M], attachment_repository: BaseAttachmentRepository[A]):
self._verify_token = token_verifier
self._attachment_repository = attachment_repository

async def __call__(self, melding_id: int, token: str) -> Collection[A_co]:
async def __call__(self, melding_id: int, token: str) -> Sequence[A]:
await self._verify_token(melding_id, token)

return await self._attachment_repository.find_by_melding(melding_id)


class DeleteAttachmentAction(Generic[A, A_co, M, M_co]):
_verify_token: TokenVerifier[M, M_co]
_attachment_repository: BaseAttachmentRepository[A, A_co]
class DeleteAttachmentAction(Generic[A, M]):
_verify_token: TokenVerifier[M]
_attachment_repository: BaseAttachmentRepository[A]
_filesystem: Filesystem

def __init__(
self,
token_verifier: TokenVerifier[M, M_co],
attachment_repository: BaseAttachmentRepository[A, A_co],
token_verifier: TokenVerifier[M],
attachment_repository: BaseAttachmentRepository[A],
filesystem: Filesystem,
):
self._verify_token = token_verifier
Expand Down
21 changes: 10 additions & 11 deletions meldingen_core/actions/base.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,45 @@
from collections.abc import Collection
from collections.abc import Sequence
from typing import Any, Generic, TypeVar

from meldingen_core import SortingDirection
from meldingen_core.exceptions import NotFoundException
from meldingen_core.repositories import BaseRepository

T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)


class BaseCRUDAction(Generic[T, T_co]):
_repository: BaseRepository[T, T_co]
class BaseCRUDAction(Generic[T]):
_repository: BaseRepository[T]

def __init__(self, repository: BaseRepository[T, T_co]) -> None:
def __init__(self, repository: BaseRepository[T]) -> None:
self._repository = repository


class BaseCreateAction(BaseCRUDAction[T, T_co]):
class BaseCreateAction(BaseCRUDAction[T]):
async def __call__(self, obj: T) -> None:
await self._repository.save(obj)


class BaseRetrieveAction(BaseCRUDAction[T, T_co]):
class BaseRetrieveAction(BaseCRUDAction[T]):
async def __call__(self, pk: int) -> T | None:
return await self._repository.retrieve(pk=pk)


class BaseListAction(BaseCRUDAction[T, T_co]):
class BaseListAction(BaseCRUDAction[T]):
async def __call__(
self,
*,
limit: int | None = None,
offset: int | None = None,
sort_attribute_name: str | None = None,
sort_direction: SortingDirection | None = None,
) -> Collection[T_co]:
) -> Sequence[T]:
return await self._repository.list(
limit=limit, offset=offset, sort_attribute_name=sort_attribute_name, sort_direction=sort_direction
)


class BaseUpdateAction(BaseCRUDAction[T, T_co]):
class BaseUpdateAction(BaseCRUDAction[T]):
async def __call__(self, pk: int, values: dict[str, Any]) -> T:
obj = await self._repository.retrieve(pk=pk)
if obj is None:
Expand All @@ -54,6 +53,6 @@ async def __call__(self, pk: int, values: dict[str, Any]) -> T:
return obj


class BaseDeleteAction(BaseCRUDAction[T, T_co]):
class BaseDeleteAction(BaseCRUDAction[T]):
async def __call__(self, pk: int) -> None:
await self._repository.delete(pk=pk)
11 changes: 5 additions & 6 deletions meldingen_core/actions/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,18 @@
from meldingen_core.models import Classification

T = TypeVar("T", bound=Classification)
T_co = TypeVar("T_co", covariant=True, bound=Classification)


class ClassificationCreateAction(BaseCreateAction[T, T_co]): ...
class ClassificationCreateAction(BaseCreateAction[T]): ...


class ClassificationListAction(BaseListAction[T, T_co]): ...
class ClassificationListAction(BaseListAction[T]): ...


class ClassificationRetrieveAction(BaseRetrieveAction[T, T_co]): ...
class ClassificationRetrieveAction(BaseRetrieveAction[T]): ...


class ClassificationUpdateAction(BaseUpdateAction[T, T_co]): ...
class ClassificationUpdateAction(BaseUpdateAction[T]): ...


class ClassificationDeleteAction(BaseDeleteAction[T, T_co]): ...
class ClassificationDeleteAction(BaseDeleteAction[T]): ...
53 changes: 26 additions & 27 deletions meldingen_core/actions/melding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
log = logging.getLogger(__name__)

T = TypeVar("T", bound=Melding)
T_co = TypeVar("T_co", covariant=True, bound=Melding)


class MeldingCreateAction(BaseCreateAction[T, T_co]):
class MeldingCreateAction(BaseCreateAction[T]):
"""Action that stores a melding."""

_classify: Classifier
Expand All @@ -27,7 +26,7 @@ class MeldingCreateAction(BaseCreateAction[T, T_co]):

def __init__(
self,
repository: BaseRepository[T, T_co],
repository: BaseRepository[T],
classifier: Classifier,
state_machine: BaseMeldingStateMachine[T],
token_generator: BaseTokenGenerator,
Expand Down Expand Up @@ -56,25 +55,25 @@ async def __call__(self, obj: T) -> None:
await self._repository.save(obj)


class MeldingListAction(BaseListAction[T, T_co]):
class MeldingListAction(BaseListAction[T]):
"""Action that retrieves a list of meldingen."""


class MeldingRetrieveAction(BaseRetrieveAction[T, T_co]):
class MeldingRetrieveAction(BaseRetrieveAction[T]):
"""Action that retrieves a melding."""


class MeldingUpdateAction(BaseCRUDAction[T, T_co]):
class MeldingUpdateAction(BaseCRUDAction[T]):
"""Action that updates the melding and reclassifies it"""

_verify_token: TokenVerifier[T, T_co]
_verify_token: TokenVerifier[T]
_classify: Classifier
_state_machine: BaseMeldingStateMachine[T]

def __init__(
self,
repository: BaseRepository[T, T_co],
token_verifier: TokenVerifier[T, T_co],
repository: BaseRepository[T],
token_verifier: TokenVerifier[T],
classifier: Classifier,
state_machine: BaseMeldingStateMachine[T],
) -> None:
Expand All @@ -98,15 +97,15 @@ async def __call__(self, pk: int, values: dict[str, Any], token: str) -> T:
return melding


class MeldingAddContactInfoAction(BaseCRUDAction[T, T_co]):
class MeldingAddContactInfoAction(BaseCRUDAction[T]):
"""Action that adds contact information to a melding."""

_verify_token: TokenVerifier[T, T_co]
_verify_token: TokenVerifier[T]

def __init__(
self,
repository: BaseMeldingRepository[T, T_co],
token_verifier: TokenVerifier[T, T_co],
repository: BaseMeldingRepository[T],
token_verifier: TokenVerifier[T],
) -> None:
super().__init__(repository)
self._verify_token = token_verifier
Expand All @@ -122,19 +121,19 @@ async def __call__(self, pk: int, phone: str | None, email: str | None, token: s
return melding


class BaseStateTransitionAction(Generic[T, T_co], metaclass=ABCMeta):
class BaseStateTransitionAction(Generic[T], metaclass=ABCMeta):
"""
This action covers transitions that do not require the melding's token to be verified.
Typically these actions are performed by authenticated users.
"""

_state_machine: BaseMeldingStateMachine[T]
_repository: BaseMeldingRepository[T, T_co]
_repository: BaseMeldingRepository[T]

def __init__(
self,
state_machine: BaseMeldingStateMachine[T],
repository: BaseMeldingRepository[T, T_co],
repository: BaseMeldingRepository[T],
):
self._state_machine = state_machine
self._repository = repository
Expand All @@ -154,21 +153,21 @@ async def __call__(self, melding_id: int) -> T:
return melding


class BaseMeldingFormStateTransitionAction(Generic[T, T_co], metaclass=ABCMeta):
class BaseMeldingFormStateTransitionAction(Generic[T], metaclass=ABCMeta):
"""
This action covers transitions that require the melding's token to be verified.
This is the case for unauthenticated state transitions where a user submits a melding.
"""

_state_machine: BaseMeldingStateMachine[T]
_repository: BaseMeldingRepository[T, T_co]
_verify_token: TokenVerifier[T, T_co]
_repository: BaseMeldingRepository[T]
_verify_token: TokenVerifier[T]

def __init__(
self,
state_machine: BaseMeldingStateMachine[T],
repository: BaseMeldingRepository[T, T_co],
token_verifier: TokenVerifier[T, T_co],
repository: BaseMeldingRepository[T],
token_verifier: TokenVerifier[T],
):
self._state_machine = state_machine
self._repository = repository
Expand All @@ -187,37 +186,37 @@ async def __call__(self, melding_id: int, token: str) -> T:
return melding


class MeldingAnswerQuestionsAction(BaseMeldingFormStateTransitionAction[T, T_co]):
class MeldingAnswerQuestionsAction(BaseMeldingFormStateTransitionAction[T]):
@property
def transition_name(self) -> str:
return MeldingTransitions.ANSWER_QUESTIONS


class MeldingAddAttachmentsAction(BaseMeldingFormStateTransitionAction[T, T_co]):
class MeldingAddAttachmentsAction(BaseMeldingFormStateTransitionAction[T]):
@property
def transition_name(self) -> str:
return MeldingTransitions.ADD_ATTACHMENTS


class MeldingSubmitLocationAction(BaseMeldingFormStateTransitionAction[T, T_co]):
class MeldingSubmitLocationAction(BaseMeldingFormStateTransitionAction[T]):
@property
def transition_name(self) -> str:
return MeldingTransitions.SUBMIT_LOCATION


class MeldingContactInfoAddedAction(BaseMeldingFormStateTransitionAction[T, T_co]):
class MeldingContactInfoAddedAction(BaseMeldingFormStateTransitionAction[T]):
@property
def transition_name(self) -> str:
return MeldingTransitions.ADD_CONTACT_INFO


class MeldingProcessAction(BaseStateTransitionAction[T, T_co]):
class MeldingProcessAction(BaseStateTransitionAction[T]):
@property
def transition_name(self) -> str:
return MeldingTransitions.PROCESS


class MeldingCompleteAction(BaseStateTransitionAction[T, T_co]):
class MeldingCompleteAction(BaseStateTransitionAction[T]):
@property
def transition_name(self) -> str:
return MeldingTransitions.COMPLETE
11 changes: 5 additions & 6 deletions meldingen_core/actions/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,23 @@
from meldingen_core.models import User

T = TypeVar("T", bound=User)
T_co = TypeVar("T_co", covariant=True, bound=User)


class UserCreateAction(BaseCreateAction[T, T_co]):
class UserCreateAction(BaseCreateAction[T]):
"""Action that add a user."""


class UserUpdateAction(BaseUpdateAction[T, T_co]):
class UserUpdateAction(BaseUpdateAction[T]):
"""Action that updates a user."""


class UserListAction(BaseListAction[T, T_co]):
class UserListAction(BaseListAction[T]):
"""Action that retrieves a list of users."""


class UserRetrieveAction(BaseRetrieveAction[T, T_co]):
class UserRetrieveAction(BaseRetrieveAction[T]):
"""Action that retrieves a user."""


class UserDeleteAction(BaseDeleteAction[T, T_co]):
class UserDeleteAction(BaseDeleteAction[T]):
"""Action that deletes a user."""
Loading

0 comments on commit 50d58f7

Please sign in to comment.