Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reduce number of required type params by using Sequence instead of … #130

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 21 additions & 25 deletions meldingen_core/actions/attachment.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Collection
from collections.abc import Sequence
from enum import StrEnum
from typing import AsyncIterator, Generic, TypeVar

Expand All @@ -14,16 +14,14 @@
from meldingen_core.validators import BaseMediaTypeIntegrityValidator, BaseMediaTypeValidator

A = TypeVar("A", bound=Attachment)
A_co = TypeVar("A_co", bound=Attachment, covariant=True)
M = TypeVar("M", bound=Melding)
M_co = TypeVar("M_co", bound=Melding, covariant=True)


class UploadAttachmentAction(Generic[A, A_co, M, M_co]):
class UploadAttachmentAction(Generic[A, M]):
_create_attachment: BaseAttachmentFactory[A, M]
_attachment_repository: BaseAttachmentRepository[A, A_co]
_attachment_repository: BaseAttachmentRepository[A]
_filesystem: Filesystem
_verify_token: TokenVerifier[M, M_co]
_verify_token: TokenVerifier[M]
_base_directory: str
_validate_media_type: BaseMediaTypeValidator
_validate_media_type_integrity: BaseMediaTypeIntegrityValidator
Expand All @@ -32,8 +30,8 @@ class UploadAttachmentAction(Generic[A, A_co, M, M_co]):
def __init__(
self,
attachment_factory: BaseAttachmentFactory[A, M],
attachment_repository: BaseAttachmentRepository[A, A_co],
token_verifier: TokenVerifier[M, M_co],
attachment_repository: BaseAttachmentRepository[A],
token_verifier: TokenVerifier[M],
media_type_validator: BaseMediaTypeValidator,
media_type_integrity_validator: BaseMediaTypeIntegrityValidator,
ingestor: BaseIngestor[A],
Expand Down Expand Up @@ -74,15 +72,15 @@ class AttachmentTypes(StrEnum):
THUMBNAIL = "thumbnail"


class DownloadAttachmentAction(Generic[A, A_co, M, M_co]):
_verify_token: TokenVerifier[M, M_co]
_attachment_repository: BaseAttachmentRepository[A, A_co]
class DownloadAttachmentAction(Generic[A, M]):
_verify_token: TokenVerifier[M]
_attachment_repository: BaseAttachmentRepository[A]
_filesystem: Filesystem

def __init__(
self,
token_verifier: TokenVerifier[M, M_co],
attachment_repository: BaseAttachmentRepository[A, A_co],
token_verifier: TokenVerifier[M],
attachment_repository: BaseAttachmentRepository[A],
filesystem: Filesystem,
):
self._verify_token = token_verifier
Expand Down Expand Up @@ -118,31 +116,29 @@ async def __call__(
raise NotFoundException("File not found") from exception


class ListAttachmentsAction(Generic[A, A_co, M, M_co]):
_verify_token: TokenVerifier[M, M_co]
_attachment_repository: BaseAttachmentRepository[A, A_co]
class ListAttachmentsAction(Generic[A, M]):
_verify_token: TokenVerifier[M]
_attachment_repository: BaseAttachmentRepository[A]

def __init__(
self, token_verifier: TokenVerifier[M, M_co], attachment_repository: BaseAttachmentRepository[A, A_co]
):
def __init__(self, token_verifier: TokenVerifier[M], attachment_repository: BaseAttachmentRepository[A]):
self._verify_token = token_verifier
self._attachment_repository = attachment_repository

async def __call__(self, melding_id: int, token: str) -> Collection[A_co]:
async def __call__(self, melding_id: int, token: str) -> Sequence[A]:
await self._verify_token(melding_id, token)

return await self._attachment_repository.find_by_melding(melding_id)


class DeleteAttachmentAction(Generic[A, A_co, M, M_co]):
_verify_token: TokenVerifier[M, M_co]
_attachment_repository: BaseAttachmentRepository[A, A_co]
class DeleteAttachmentAction(Generic[A, M]):
_verify_token: TokenVerifier[M]
_attachment_repository: BaseAttachmentRepository[A]
_filesystem: Filesystem

def __init__(
self,
token_verifier: TokenVerifier[M, M_co],
attachment_repository: BaseAttachmentRepository[A, A_co],
token_verifier: TokenVerifier[M],
attachment_repository: BaseAttachmentRepository[A],
filesystem: Filesystem,
):
self._verify_token = token_verifier
Expand Down
21 changes: 10 additions & 11 deletions meldingen_core/actions/base.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,45 @@
from collections.abc import Collection
from collections.abc import Sequence
from typing import Any, Generic, TypeVar

from meldingen_core import SortingDirection
from meldingen_core.exceptions import NotFoundException
from meldingen_core.repositories import BaseRepository

T = TypeVar("T")
T_co = TypeVar("T_co", covariant=True)


class BaseCRUDAction(Generic[T, T_co]):
_repository: BaseRepository[T, T_co]
class BaseCRUDAction(Generic[T]):
_repository: BaseRepository[T]

def __init__(self, repository: BaseRepository[T, T_co]) -> None:
def __init__(self, repository: BaseRepository[T]) -> None:
self._repository = repository


class BaseCreateAction(BaseCRUDAction[T, T_co]):
class BaseCreateAction(BaseCRUDAction[T]):
async def __call__(self, obj: T) -> None:
await self._repository.save(obj)


class BaseRetrieveAction(BaseCRUDAction[T, T_co]):
class BaseRetrieveAction(BaseCRUDAction[T]):
async def __call__(self, pk: int) -> T | None:
return await self._repository.retrieve(pk=pk)


class BaseListAction(BaseCRUDAction[T, T_co]):
class BaseListAction(BaseCRUDAction[T]):
async def __call__(
self,
*,
limit: int | None = None,
offset: int | None = None,
sort_attribute_name: str | None = None,
sort_direction: SortingDirection | None = None,
) -> Collection[T_co]:
) -> Sequence[T]:
return await self._repository.list(
limit=limit, offset=offset, sort_attribute_name=sort_attribute_name, sort_direction=sort_direction
)


class BaseUpdateAction(BaseCRUDAction[T, T_co]):
class BaseUpdateAction(BaseCRUDAction[T]):
async def __call__(self, pk: int, values: dict[str, Any]) -> T:
obj = await self._repository.retrieve(pk=pk)
if obj is None:
Expand All @@ -54,6 +53,6 @@ async def __call__(self, pk: int, values: dict[str, Any]) -> T:
return obj


class BaseDeleteAction(BaseCRUDAction[T, T_co]):
class BaseDeleteAction(BaseCRUDAction[T]):
async def __call__(self, pk: int) -> None:
await self._repository.delete(pk=pk)
11 changes: 5 additions & 6 deletions meldingen_core/actions/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,18 @@
from meldingen_core.models import Classification

T = TypeVar("T", bound=Classification)
T_co = TypeVar("T_co", covariant=True, bound=Classification)


class ClassificationCreateAction(BaseCreateAction[T, T_co]): ...
class ClassificationCreateAction(BaseCreateAction[T]): ...


class ClassificationListAction(BaseListAction[T, T_co]): ...
class ClassificationListAction(BaseListAction[T]): ...


class ClassificationRetrieveAction(BaseRetrieveAction[T, T_co]): ...
class ClassificationRetrieveAction(BaseRetrieveAction[T]): ...


class ClassificationUpdateAction(BaseUpdateAction[T, T_co]): ...
class ClassificationUpdateAction(BaseUpdateAction[T]): ...


class ClassificationDeleteAction(BaseDeleteAction[T, T_co]): ...
class ClassificationDeleteAction(BaseDeleteAction[T]): ...
53 changes: 26 additions & 27 deletions meldingen_core/actions/melding.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
log = logging.getLogger(__name__)

T = TypeVar("T", bound=Melding)
T_co = TypeVar("T_co", covariant=True, bound=Melding)


class MeldingCreateAction(BaseCreateAction[T, T_co]):
class MeldingCreateAction(BaseCreateAction[T]):
"""Action that stores a melding."""

_classify: Classifier
Expand All @@ -27,7 +26,7 @@ class MeldingCreateAction(BaseCreateAction[T, T_co]):

def __init__(
self,
repository: BaseRepository[T, T_co],
repository: BaseRepository[T],
classifier: Classifier,
state_machine: BaseMeldingStateMachine[T],
token_generator: BaseTokenGenerator,
Expand Down Expand Up @@ -56,25 +55,25 @@ async def __call__(self, obj: T) -> None:
await self._repository.save(obj)


class MeldingListAction(BaseListAction[T, T_co]):
class MeldingListAction(BaseListAction[T]):
"""Action that retrieves a list of meldingen."""


class MeldingRetrieveAction(BaseRetrieveAction[T, T_co]):
class MeldingRetrieveAction(BaseRetrieveAction[T]):
"""Action that retrieves a melding."""


class MeldingUpdateAction(BaseCRUDAction[T, T_co]):
class MeldingUpdateAction(BaseCRUDAction[T]):
"""Action that updates the melding and reclassifies it"""

_verify_token: TokenVerifier[T, T_co]
_verify_token: TokenVerifier[T]
_classify: Classifier
_state_machine: BaseMeldingStateMachine[T]

def __init__(
self,
repository: BaseRepository[T, T_co],
token_verifier: TokenVerifier[T, T_co],
repository: BaseRepository[T],
token_verifier: TokenVerifier[T],
classifier: Classifier,
state_machine: BaseMeldingStateMachine[T],
) -> None:
Expand All @@ -98,15 +97,15 @@ async def __call__(self, pk: int, values: dict[str, Any], token: str) -> T:
return melding


class MeldingAddContactInfoAction(BaseCRUDAction[T, T_co]):
class MeldingAddContactInfoAction(BaseCRUDAction[T]):
"""Action that adds contact information to a melding."""

_verify_token: TokenVerifier[T, T_co]
_verify_token: TokenVerifier[T]

def __init__(
self,
repository: BaseMeldingRepository[T, T_co],
token_verifier: TokenVerifier[T, T_co],
repository: BaseMeldingRepository[T],
token_verifier: TokenVerifier[T],
) -> None:
super().__init__(repository)
self._verify_token = token_verifier
Expand All @@ -122,19 +121,19 @@ async def __call__(self, pk: int, phone: str | None, email: str | None, token: s
return melding


class BaseStateTransitionAction(Generic[T, T_co], metaclass=ABCMeta):
class BaseStateTransitionAction(Generic[T], metaclass=ABCMeta):
"""
This action covers transitions that do not require the melding's token to be verified.
Typically these actions are performed by authenticated users.
"""

_state_machine: BaseMeldingStateMachine[T]
_repository: BaseMeldingRepository[T, T_co]
_repository: BaseMeldingRepository[T]

def __init__(
self,
state_machine: BaseMeldingStateMachine[T],
repository: BaseMeldingRepository[T, T_co],
repository: BaseMeldingRepository[T],
):
self._state_machine = state_machine
self._repository = repository
Expand All @@ -154,21 +153,21 @@ async def __call__(self, melding_id: int) -> T:
return melding


class BaseMeldingFormStateTransitionAction(Generic[T, T_co], metaclass=ABCMeta):
class BaseMeldingFormStateTransitionAction(Generic[T], metaclass=ABCMeta):
"""
This action covers transitions that require the melding's token to be verified.
This is the case for unauthenticated state transitions where a user submits a melding.
"""

_state_machine: BaseMeldingStateMachine[T]
_repository: BaseMeldingRepository[T, T_co]
_verify_token: TokenVerifier[T, T_co]
_repository: BaseMeldingRepository[T]
_verify_token: TokenVerifier[T]

def __init__(
self,
state_machine: BaseMeldingStateMachine[T],
repository: BaseMeldingRepository[T, T_co],
token_verifier: TokenVerifier[T, T_co],
repository: BaseMeldingRepository[T],
token_verifier: TokenVerifier[T],
):
self._state_machine = state_machine
self._repository = repository
Expand All @@ -187,37 +186,37 @@ async def __call__(self, melding_id: int, token: str) -> T:
return melding


class MeldingAnswerQuestionsAction(BaseMeldingFormStateTransitionAction[T, T_co]):
class MeldingAnswerQuestionsAction(BaseMeldingFormStateTransitionAction[T]):
@property
def transition_name(self) -> str:
return MeldingTransitions.ANSWER_QUESTIONS


class MeldingAddAttachmentsAction(BaseMeldingFormStateTransitionAction[T, T_co]):
class MeldingAddAttachmentsAction(BaseMeldingFormStateTransitionAction[T]):
@property
def transition_name(self) -> str:
return MeldingTransitions.ADD_ATTACHMENTS


class MeldingSubmitLocationAction(BaseMeldingFormStateTransitionAction[T, T_co]):
class MeldingSubmitLocationAction(BaseMeldingFormStateTransitionAction[T]):
@property
def transition_name(self) -> str:
return MeldingTransitions.SUBMIT_LOCATION


class MeldingContactInfoAddedAction(BaseMeldingFormStateTransitionAction[T, T_co]):
class MeldingContactInfoAddedAction(BaseMeldingFormStateTransitionAction[T]):
@property
def transition_name(self) -> str:
return MeldingTransitions.ADD_CONTACT_INFO


class MeldingProcessAction(BaseStateTransitionAction[T, T_co]):
class MeldingProcessAction(BaseStateTransitionAction[T]):
@property
def transition_name(self) -> str:
return MeldingTransitions.PROCESS


class MeldingCompleteAction(BaseStateTransitionAction[T, T_co]):
class MeldingCompleteAction(BaseStateTransitionAction[T]):
@property
def transition_name(self) -> str:
return MeldingTransitions.COMPLETE
11 changes: 5 additions & 6 deletions meldingen_core/actions/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,23 @@
from meldingen_core.models import User

T = TypeVar("T", bound=User)
T_co = TypeVar("T_co", covariant=True, bound=User)


class UserCreateAction(BaseCreateAction[T, T_co]):
class UserCreateAction(BaseCreateAction[T]):
"""Action that add a user."""


class UserUpdateAction(BaseUpdateAction[T, T_co]):
class UserUpdateAction(BaseUpdateAction[T]):
"""Action that updates a user."""


class UserListAction(BaseListAction[T, T_co]):
class UserListAction(BaseListAction[T]):
"""Action that retrieves a list of users."""


class UserRetrieveAction(BaseRetrieveAction[T, T_co]):
class UserRetrieveAction(BaseRetrieveAction[T]):
"""Action that retrieves a user."""


class UserDeleteAction(BaseDeleteAction[T, T_co]):
class UserDeleteAction(BaseDeleteAction[T]):
"""Action that deletes a user."""
Loading