From 47b2baab2f76a39fd58d0125415db1df81a1187d Mon Sep 17 00:00:00 2001 From: hawk-tomy Date: Wed, 11 Sep 2024 03:56:42 +0900 Subject: [PATCH] update: refactor Modal feature --- discord/ext/flow/__init__.py | 3 +- discord/ext/flow/modal.py | 101 ++++++++++++++++++++++++++++------- discord/ext/flow/model.py | 19 ++++--- discord/ext/flow/pages.py | 52 +++++++++++++----- discord/ext/flow/view.py | 30 ++++++++--- example/modal.py | 63 ++++++++++++---------- 6 files changed, 192 insertions(+), 76 deletions(-) diff --git a/discord/ext/flow/__init__.py b/discord/ext/flow/__init__.py index 1f1f66e..1f98598 100644 --- a/discord/ext/flow/__init__.py +++ b/discord/ext/flow/__init__.py @@ -21,7 +21,8 @@ 'Controller', 'TextInput', 'ModalConfig', - 'send_modal', + 'ModalResult', + 'ModalController', 'Paginator', 'paginator', 'Result', diff --git a/discord/ext/flow/modal.py b/discord/ext/flow/modal.py index 3321ce7..5ed3f70 100644 --- a/discord/ext/flow/modal.py +++ b/discord/ext/flow/modal.py @@ -1,14 +1,16 @@ from __future__ import annotations +from asyncio import Future, Task, get_running_loop from dataclasses import dataclass -from typing import TYPE_CHECKING, TypedDict +from typing import TYPE_CHECKING, NamedTuple, TypedDict from discord import Client, Interaction, TextStyle, ui if TYPE_CHECKING: from collections.abc import Sequence -__all__ = ('TextInput', 'ModalConfig', 'send_modal') + +__all__ = ('TextInput', 'ModalConfig', 'ModalResult', 'ModalController') @dataclass @@ -53,10 +55,16 @@ class ModalConfigKWargs(TypedDict, total=False): custom_id: str -class InnerModal(ui.Modal): - results: tuple[str, ...] +class ModalResult(NamedTuple): + """Result of modal. length of texts is same as length of text_inputs in ModalConfig.""" + + texts: tuple[str, ...] interaction: Interaction[Client] + +class InnerModal(ui.Modal): + result: ModalResult + def __init__(self, config: ModalConfig, text_inputs: Sequence[TextInput]) -> None: kwargs: ModalConfigKWargs = {'title': config.title, 'timeout': config.timeout} if config.custom_id is not None: @@ -82,25 +90,82 @@ async def on_submit(self, interaction: Interaction[Client]) -> None: for child in self.children: assert isinstance(child, ui.TextInput) results.append(child.value) - self.results = tuple(results) - self.interaction = interaction + self.result = ModalResult(tuple(results), interaction) self.stop() -async def send_modal( - interaction: Interaction[Client], config: ModalConfig, text_inputs: Sequence[TextInput] -) -> tuple[tuple[str, ...], Interaction[Client]]: - """Text input modal. +class ModalController: + """Modal controller. + + you should... + - construct this class and save it to Model. Args: - interaction (Interaction): Interaction to send modal. This interaction will be consumed. config (ModalConfig): config for modal. text_inputs (Sequence[TextInput]): text inputs for modal. - - Returns: - tuple[tuple[str, ...], Interaction]: results and interaction. results length is same as text_inputs length. """ - inner_modal = InnerModal(config, text_inputs) - await interaction.response.send_modal(inner_modal) - await inner_modal.wait() - return inner_modal.results, inner_modal.interaction + + def __init__(self, config: ModalConfig, text_inputs: Sequence[TextInput]) -> None: + self.__stopped: Future[bool] = get_running_loop().create_future() + self.config = config + self.text_inputs = text_inputs + self.modals: list[tuple[InnerModal, Task[None]]] = [] + + async def _wait_modal(self, modal: InnerModal, result_future: Future[ModalResult]) -> None: + await modal.wait() + + if self.__stopped.done(): + result_future.cancel() + return + + self.__stopped.set_result(True) + result_future.set_result(modal.result) + self.__inner_cancel() + + def stop(self) -> None: + """Stop all modals. You should call this method in Model.after_invoke method.""" + self.__stopped.set_result(False) + self.__inner_cancel() + + def __inner_cancel(self) -> None: + for m, t in self.modals: + m.stop() + t.cancel() + + def is_finished(self) -> bool: + """This modal controller is finished or not. + + Returns: + bool: True if finished. + """ + return self.__stopped.done() + + async def wait(self) -> bool: + """Wait until all modals are finished. + + Returns: + bool: False if call stop method. + """ + return await self.__stopped + + async def send_modal(self, interaction: Interaction) -> ModalResult: + """Send modal. call this method in any view item callback. + + Args: + interaction (Interaction): interaction to send modal. + + Returns: + ModalResult: result of modal. + + Exceptions: + asyncio.CancelledError: + if call stop method, receive value from user interaction in other model or any other reason, + raise this error. if you want catch this error and call this in flow's callback, + you should raise any Error. not return Result. + """ + result_future: Future[ModalResult] = get_running_loop().create_future() + modal = InnerModal(self.config, self.text_inputs) + await interaction.response.send_modal(modal) + task = get_running_loop().create_task(self._wait_modal(modal, result_future)) + self.modals.append((modal, task)) + return await result_future diff --git a/discord/ext/flow/model.py b/discord/ext/flow/model.py index 9533e10..e21de34 100644 --- a/discord/ext/flow/model.py +++ b/discord/ext/flow/model.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, NamedTuple +from typing import TYPE_CHECKING, NamedTuple, ParamSpec, TypeVar from discord import ButtonStyle @@ -19,7 +19,7 @@ if TYPE_CHECKING: - from collections.abc import Callable, Sequence + from collections.abc import Awaitable, Callable, Sequence from typing import Any, TypeAlias, TypedDict from discord import ( @@ -84,6 +84,9 @@ class MessageKwargs(TypedDict, total=False): | AppCommandThread | Thread ) + T = TypeVar('T') + P = ParamSpec('P') + MaybeAwaitableFunc = Callable[P, T] | Callable[P, Awaitable[T]] class Message(NamedTuple): @@ -143,7 +146,7 @@ class Button: - you should use Link instead of this if you want to send link. """ - callback: Callable[[Interaction[Client]], MaybeAwaitable[Result]] + callback: MaybeAwaitableFunc[[Interaction[Client]], Result] label: str | None = None custom_id: str | None = None disabled: bool = False @@ -170,7 +173,7 @@ class Select: - options is keyword only argument. """ - callback: Callable[[Interaction[Client], list[str]], MaybeAwaitable[Result]] + callback: MaybeAwaitableFunc[[Interaction[Client], list[str]], Result] placeholder: str | None = None custom_id: str | None = None min_values: int = 1 @@ -184,7 +187,7 @@ class Select: class UserSelect: """discord.ui.UserSelect with callback for Message.items.""" - callback: Callable[[Interaction[Client], list[User | Member]], MaybeAwaitable[Result]] + callback: MaybeAwaitableFunc[[Interaction[Client], list[User | Member]], Result] placeholder: str | None = None custom_id: str | None = None min_values: int = 1 @@ -198,7 +201,7 @@ class UserSelect: class RoleSelect: """discord.ui.RoleSelect with callback for Message.items.""" - callback: Callable[[Interaction[Client], list[Role]], MaybeAwaitable[Result]] + callback: MaybeAwaitableFunc[[Interaction[Client], list[Role]], Result] placeholder: str | None = None custom_id: str | None = None min_values: int = 1 @@ -212,7 +215,7 @@ class RoleSelect: class MentionableSelect: """discord.ui.MentionableSelect with callback for Message.items.""" - callback: Callable[[Interaction[Client], list[User | Member | Role]], MaybeAwaitable[Result]] + callback: MaybeAwaitableFunc[[Interaction[Client], list[User | Member | Role]], Result] placeholder: str | None = None custom_id: str | None = None min_values: int = 1 @@ -226,7 +229,7 @@ class MentionableSelect: class ChannelSelect: """discord.ui.ChannelSelect with callback for Message.items.""" - callback: Callable[[Interaction[Client], list[AppCommandChannel | AppCommandThread]], MaybeAwaitable[Result]] + callback: MaybeAwaitableFunc[[Interaction[Client], list[AppCommandChannel | AppCommandThread]], Result] placeholder: str | None = None custom_id: str | None = None min_values: int = 1 diff --git a/discord/ext/flow/pages.py b/discord/ext/flow/pages.py index 68053ef..265d257 100644 --- a/discord/ext/flow/pages.py +++ b/discord/ext/flow/pages.py @@ -4,9 +4,9 @@ from discord.utils import maybe_coroutine -from .modal import ModalConfig, TextInput, send_modal -from .model import Button, Message -from .result import Result +from .modal import ModalConfig, ModalController, TextInput +from .model import Button, ItemType, Message +from .result import Result, _ResultTypeEnum if TYPE_CHECKING: from collections.abc import Awaitable, Callable, Sequence @@ -63,6 +63,10 @@ def __init__( div, mod = divmod(len(values), per_page) self.max_page = div + (mod != 0) self.row = row + self.modal_controller = ModalController( + ModalConfig(title='Page Number'), + (TextInput(label='page number', placeholder=f'1 ~ {self.max_page}', required=True),), + ) async def _message(self, *, edit_original: bool = False) -> Message: msg = await maybe_coroutine( @@ -71,7 +75,13 @@ async def _message(self, *, edit_original: bool = False) -> Message: self.current_page, self.max_page, ) - items = () if msg.items is None else tuple(msg.items) + items: list[ItemType] = [] + if msg.items is not None: + for item in msg.items: + if hasattr(item, 'callback'): + # type safe, because we are sure that item has callback attribute and this is not change type hint. + item.callback = self._finalize_modal(item.callback) # type: ignore[reportUnknownMemberType, reportArgumentType, reportAttributeAccessIssue] + items.append(item) if len(items) > 20: raise ValueError('Message.items must be less than 20') @@ -89,7 +99,25 @@ async def _message(self, *, edit_original: bool = False) -> Message: Button(emoji=LAST_EMOJI, row=self.row, disabled=is_final_page, callback=self._go_to_last_page), ) - return msg._replace(items=items + control_items, edit_original=edit_original or msg.edit_original) + return msg._replace(items=tuple(items) + control_items, edit_original=edit_original or msg.edit_original) + + def _finalize_modal(self, callback: MaybeAwaitableFunc[P, Result]) -> Callable[P, Awaitable[Result]]: + async def finalize(*args: P.args, **kwargs: P.kwargs) -> Result: + result = await maybe_coroutine(callback, *args, **kwargs) + if ( + ( + result._type in (_ResultTypeEnum.MODEL, _ResultTypeEnum.FINISH) + ) or ( + result._type == _ResultTypeEnum.MESSAGE + and result._message is not None + and (not result._message.items or result._message.disable_items) + ) + ): # fmt: skip + # it is stop view and pagination is finished. so, we can stop all modals. + self.modal_controller.stop() + return result + + return finalize def _set_page_number(self, page_number: int) -> None: if 0 <= page_number < self.max_page: @@ -104,15 +132,11 @@ async def _go_to_previous_page(self, _: Interaction[Client]) -> Result: return Result.send_message(message=await self._message(edit_original=True)) async def _go_to_page(self, interaction: Interaction[Client]) -> Result: - texts, interaction = await send_modal( - interaction, - ModalConfig(title='Page Number'), - (TextInput(label='page number', placeholder=f'1 ~ {self.max_page}', required=True),), - ) - assert len(texts) >= 1 - assert texts[0].isdigit() - self._set_page_number(int(texts[0]) - 1) - return Result.send_message(message=await self._message(edit_original=True), interaction=interaction) + result = await self.modal_controller.send_modal(interaction) + assert len(result.texts) >= 1 + assert result.texts[0].isdigit() + self._set_page_number(int(result.texts[0]) - 1) + return Result.send_message(message=await self._message(edit_original=True), interaction=result.interaction) async def _go_to_next_page(self, _: Interaction[Client]) -> Result: self._set_page_number(self.current_page + 1) diff --git a/discord/ext/flow/view.py b/discord/ext/flow/view.py index 260574c..8777696 100644 --- a/discord/ext/flow/view.py +++ b/discord/ext/flow/view.py @@ -1,5 +1,7 @@ from __future__ import annotations +from asyncio import CancelledError +from contextlib import suppress from typing import TYPE_CHECKING from discord import Client, Interaction, ui @@ -31,7 +33,8 @@ def __init__(self, config: Button) -> None: self.config = config async def callback(self, interaction: Interaction[Client]) -> None: - await self.view.set_result(await maybe_coroutine(self.config.callback, interaction), interaction) + with suppress(CancelledError): + await self.view.set_result(await maybe_coroutine(self.config.callback, interaction), interaction) class _Link(ui.Button['_View']): @@ -55,7 +58,10 @@ def __init__(self, config: Select) -> None: self.config = config async def callback(self, interaction: Interaction[Client]) -> None: - await self.view.set_result(await maybe_coroutine(self.config.callback, interaction, self.values), interaction) + with suppress(CancelledError): + await self.view.set_result( + await maybe_coroutine(self.config.callback, interaction, self.values), interaction + ) class _UserSelect(ui.UserSelect['_View']): @@ -74,7 +80,10 @@ def __init__(self, config: UserSelect) -> None: self.config = config async def callback(self, interaction: Interaction[Client]) -> None: - await self.view.set_result(await maybe_coroutine(self.config.callback, interaction, self.values), interaction) + with suppress(CancelledError): + await self.view.set_result( + await maybe_coroutine(self.config.callback, interaction, self.values), interaction + ) class _RoleSelect(ui.RoleSelect['_View']): @@ -93,7 +102,10 @@ def __init__(self, config: RoleSelect) -> None: self.config = config async def callback(self, interaction: Interaction[Client]) -> None: - await self.view.set_result(await maybe_coroutine(self.config.callback, interaction, self.values), interaction) + with suppress(CancelledError): + await self.view.set_result( + await maybe_coroutine(self.config.callback, interaction, self.values), interaction + ) class _MentionableSelect(ui.MentionableSelect['_View']): @@ -112,7 +124,10 @@ def __init__(self, config: MentionableSelect) -> None: self.config = config async def callback(self, interaction: Interaction[Client]) -> None: - await self.view.set_result(await maybe_coroutine(self.config.callback, interaction, self.values), interaction) + with suppress(CancelledError): + await self.view.set_result( + await maybe_coroutine(self.config.callback, interaction, self.values), interaction + ) class _ChannelSelect(ui.ChannelSelect['_View']): @@ -131,7 +146,10 @@ def __init__(self, config: ChannelSelect) -> None: self.config = config async def callback(self, interaction: Interaction[Client]) -> None: - await self.view.set_result(await maybe_coroutine(self.config.callback, interaction, self.values), interaction) + with suppress(CancelledError): + await self.view.set_result( + await maybe_coroutine(self.config.callback, interaction, self.values), interaction + ) class _View(ui.View): diff --git a/example/modal.py b/example/modal.py index 464e9b5..f557bbd 100644 --- a/example/modal.py +++ b/example/modal.py @@ -4,12 +4,27 @@ from discord import Client, Embed, Intents, Interaction from discord.app_commands import CommandTree -from discord.ext.flow import Button, Controller, Message, ModalConfig, ModelBase, Result, TextInput, send_modal +from discord.ext.flow import Button, Controller, Message, ModalConfig, ModalController, ModelBase, Result, TextInput class EmbedModel(ModelBase): def __init__(self, title: str) -> None: self.embed = Embed(title=title, description='') + self.title_modal = ModalController( + ModalConfig(title='Edit Title'), + (TextInput(label='title', default=self.embed.title),), + ) + self.description_modal = ModalController( + ModalConfig(title='Edit Description'), + (TextInput(label='description', default=self.embed.description),), + ) + self.title_and_description_modal = ModalController( + ModalConfig(title='Edit Title and Description'), + ( + TextInput(label='title', default=self.embed.title), + TextInput(label='description', default=self.embed.description), + ), + ) def message(self) -> Message: return Message( @@ -25,46 +40,36 @@ def message(self) -> Message: ephemeral=True, ) + async def after_invoke(self) -> None: + self.title_modal.stop() + self.description_modal.stop() + self.title_and_description_modal.stop() + def edit_title_button(self) -> Button: async def inner(interaction: Interaction[Client]) -> Result: - texts, interaction = await send_modal( - interaction, - ModalConfig(title='Edit Title'), - (TextInput(label='title', default=self.embed.title),), - ) - assert len(texts) >= 1 - self.embed.title = texts[0] - return Result.send_message(message=self.message(), interaction=interaction) + result = await self.title_modal.send_modal(interaction) + assert len(result.texts) >= 1 + self.embed.title = result.texts[0] + return Result.send_message(message=self.message(), interaction=result.interaction) return Button(label='edit title', callback=inner) def edit_description_button(self) -> Button: async def inner(interaction: Interaction[Client]) -> Result: - texts, interaction = await send_modal( - interaction, - ModalConfig(title='Edit Description'), - (TextInput(label='description', default=self.embed.description),), - ) - assert len(texts) >= 1 - self.embed.description = texts[0] - return Result.send_message(message=self.message(), interaction=interaction) + result = await self.description_modal.send_modal(interaction) + assert len(result.texts) >= 1 + self.embed.description = result.texts[0] + return Result.send_message(message=self.message(), interaction=result.interaction) return Button(label='edit description', callback=inner) def edit_title_and_description_button(self) -> Button: async def inner(interaction: Interaction[Client]) -> Result: - texts, interaction = await send_modal( - interaction, - ModalConfig(title='Edit Title and Description'), - ( - TextInput(label='title', default=self.embed.title), - TextInput(label='description', default=self.embed.description), - ), - ) - assert len(texts) >= 2 - self.embed.title = texts[0] - self.embed.description = texts[1] - return Result.send_message(message=self.message(), interaction=interaction) + result = await self.title_and_description_modal.send_modal(interaction) + assert len(result.texts) >= 2 + self.embed.title = result.texts[0] + self.embed.description = result.texts[1] + return Result.send_message(message=self.message(), interaction=result.interaction) return Button(label='edit title and description', callback=inner)