Skip to content

Commit

Permalink
update: refactor Modal feature
Browse files Browse the repository at this point in the history
  • Loading branch information
hawk-tomy committed Sep 10, 2024
1 parent b591e61 commit 47b2baa
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 76 deletions.
3 changes: 2 additions & 1 deletion discord/ext/flow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
'Controller',
'TextInput',
'ModalConfig',
'send_modal',
'ModalResult',
'ModalController',
'Paginator',
'paginator',
'Result',
Expand Down
101 changes: 83 additions & 18 deletions discord/ext/flow/modal.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations

from asyncio import Future, Task, get_running_loop
from dataclasses import dataclass
from typing import TYPE_CHECKING, TypedDict
from typing import TYPE_CHECKING, NamedTuple, TypedDict

from discord import Client, Interaction, TextStyle, ui

if TYPE_CHECKING:
from collections.abc import Sequence

__all__ = ('TextInput', 'ModalConfig', 'send_modal')

__all__ = ('TextInput', 'ModalConfig', 'ModalResult', 'ModalController')


@dataclass
Expand Down Expand Up @@ -53,10 +55,16 @@ class ModalConfigKWargs(TypedDict, total=False):
custom_id: str


class InnerModal(ui.Modal):
results: tuple[str, ...]
class ModalResult(NamedTuple):
"""Result of modal. length of texts is same as length of text_inputs in ModalConfig."""

texts: tuple[str, ...]
interaction: Interaction[Client]


class InnerModal(ui.Modal):
result: ModalResult

def __init__(self, config: ModalConfig, text_inputs: Sequence[TextInput]) -> None:
kwargs: ModalConfigKWargs = {'title': config.title, 'timeout': config.timeout}
if config.custom_id is not None:
Expand All @@ -82,25 +90,82 @@ async def on_submit(self, interaction: Interaction[Client]) -> None:
for child in self.children:
assert isinstance(child, ui.TextInput)
results.append(child.value)
self.results = tuple(results)
self.interaction = interaction
self.result = ModalResult(tuple(results), interaction)
self.stop()


async def send_modal(
interaction: Interaction[Client], config: ModalConfig, text_inputs: Sequence[TextInput]
) -> tuple[tuple[str, ...], Interaction[Client]]:
"""Text input modal.
class ModalController:
"""Modal controller.
you should...
- construct this class and save it to Model.
Args:
interaction (Interaction): Interaction to send modal. This interaction will be consumed.
config (ModalConfig): config for modal.
text_inputs (Sequence[TextInput]): text inputs for modal.
Returns:
tuple[tuple[str, ...], Interaction]: results and interaction. results length is same as text_inputs length.
"""
inner_modal = InnerModal(config, text_inputs)
await interaction.response.send_modal(inner_modal)
await inner_modal.wait()
return inner_modal.results, inner_modal.interaction

def __init__(self, config: ModalConfig, text_inputs: Sequence[TextInput]) -> None:
self.__stopped: Future[bool] = get_running_loop().create_future()
self.config = config
self.text_inputs = text_inputs
self.modals: list[tuple[InnerModal, Task[None]]] = []

async def _wait_modal(self, modal: InnerModal, result_future: Future[ModalResult]) -> None:
await modal.wait()

if self.__stopped.done():
result_future.cancel()
return

self.__stopped.set_result(True)
result_future.set_result(modal.result)
self.__inner_cancel()

def stop(self) -> None:
"""Stop all modals. You should call this method in Model.after_invoke method."""
self.__stopped.set_result(False)
self.__inner_cancel()

def __inner_cancel(self) -> None:
for m, t in self.modals:
m.stop()
t.cancel()

def is_finished(self) -> bool:
"""This modal controller is finished or not.
Returns:
bool: True if finished.
"""
return self.__stopped.done()

async def wait(self) -> bool:
"""Wait until all modals are finished.
Returns:
bool: False if call stop method.
"""
return await self.__stopped

async def send_modal(self, interaction: Interaction) -> ModalResult:
"""Send modal. call this method in any view item callback.
Args:
interaction (Interaction): interaction to send modal.
Returns:
ModalResult: result of modal.
Exceptions:
asyncio.CancelledError:
if call stop method, receive value from user interaction in other model or any other reason,
raise this error. if you want catch this error and call this in flow's callback,
you should raise any Error. not return Result.
"""
result_future: Future[ModalResult] = get_running_loop().create_future()
modal = InnerModal(self.config, self.text_inputs)
await interaction.response.send_modal(modal)
task = get_running_loop().create_task(self._wait_modal(modal, result_future))
self.modals.append((modal, task))
return await result_future
19 changes: 11 additions & 8 deletions discord/ext/flow/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, NamedTuple
from typing import TYPE_CHECKING, NamedTuple, ParamSpec, TypeVar

from discord import ButtonStyle

Expand All @@ -19,7 +19,7 @@


if TYPE_CHECKING:
from collections.abc import Callable, Sequence
from collections.abc import Awaitable, Callable, Sequence
from typing import Any, TypeAlias, TypedDict

from discord import (
Expand Down Expand Up @@ -84,6 +84,9 @@ class MessageKwargs(TypedDict, total=False):
| AppCommandThread
| Thread
)
T = TypeVar('T')
P = ParamSpec('P')
MaybeAwaitableFunc = Callable[P, T] | Callable[P, Awaitable[T]]


class Message(NamedTuple):
Expand Down Expand Up @@ -143,7 +146,7 @@ class Button:
- you should use Link instead of this if you want to send link.
"""

callback: Callable[[Interaction[Client]], MaybeAwaitable[Result]]
callback: MaybeAwaitableFunc[[Interaction[Client]], Result]
label: str | None = None
custom_id: str | None = None
disabled: bool = False
Expand All @@ -170,7 +173,7 @@ class Select:
- options is keyword only argument.
"""

callback: Callable[[Interaction[Client], list[str]], MaybeAwaitable[Result]]
callback: MaybeAwaitableFunc[[Interaction[Client], list[str]], Result]
placeholder: str | None = None
custom_id: str | None = None
min_values: int = 1
Expand All @@ -184,7 +187,7 @@ class Select:
class UserSelect:
"""discord.ui.UserSelect with callback for Message.items."""

callback: Callable[[Interaction[Client], list[User | Member]], MaybeAwaitable[Result]]
callback: MaybeAwaitableFunc[[Interaction[Client], list[User | Member]], Result]
placeholder: str | None = None
custom_id: str | None = None
min_values: int = 1
Expand All @@ -198,7 +201,7 @@ class UserSelect:
class RoleSelect:
"""discord.ui.RoleSelect with callback for Message.items."""

callback: Callable[[Interaction[Client], list[Role]], MaybeAwaitable[Result]]
callback: MaybeAwaitableFunc[[Interaction[Client], list[Role]], Result]
placeholder: str | None = None
custom_id: str | None = None
min_values: int = 1
Expand All @@ -212,7 +215,7 @@ class RoleSelect:
class MentionableSelect:
"""discord.ui.MentionableSelect with callback for Message.items."""

callback: Callable[[Interaction[Client], list[User | Member | Role]], MaybeAwaitable[Result]]
callback: MaybeAwaitableFunc[[Interaction[Client], list[User | Member | Role]], Result]
placeholder: str | None = None
custom_id: str | None = None
min_values: int = 1
Expand All @@ -226,7 +229,7 @@ class MentionableSelect:
class ChannelSelect:
"""discord.ui.ChannelSelect with callback for Message.items."""

callback: Callable[[Interaction[Client], list[AppCommandChannel | AppCommandThread]], MaybeAwaitable[Result]]
callback: MaybeAwaitableFunc[[Interaction[Client], list[AppCommandChannel | AppCommandThread]], Result]
placeholder: str | None = None
custom_id: str | None = None
min_values: int = 1
Expand Down
52 changes: 38 additions & 14 deletions discord/ext/flow/pages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from discord.utils import maybe_coroutine

from .modal import ModalConfig, TextInput, send_modal
from .model import Button, Message
from .result import Result
from .modal import ModalConfig, ModalController, TextInput
from .model import Button, ItemType, Message
from .result import Result, _ResultTypeEnum

if TYPE_CHECKING:
from collections.abc import Awaitable, Callable, Sequence
Expand Down Expand Up @@ -63,6 +63,10 @@ def __init__(
div, mod = divmod(len(values), per_page)
self.max_page = div + (mod != 0)
self.row = row
self.modal_controller = ModalController(
ModalConfig(title='Page Number'),
(TextInput(label='page number', placeholder=f'1 ~ {self.max_page}', required=True),),
)

async def _message(self, *, edit_original: bool = False) -> Message:
msg = await maybe_coroutine(
Expand All @@ -71,7 +75,13 @@ async def _message(self, *, edit_original: bool = False) -> Message:
self.current_page,
self.max_page,
)
items = () if msg.items is None else tuple(msg.items)
items: list[ItemType] = []
if msg.items is not None:
for item in msg.items:
if hasattr(item, 'callback'):
# type safe, because we are sure that item has callback attribute and this is not change type hint.
item.callback = self._finalize_modal(item.callback) # type: ignore[reportUnknownMemberType, reportArgumentType, reportAttributeAccessIssue]
items.append(item)
if len(items) > 20:
raise ValueError('Message.items must be less than 20')

Expand All @@ -89,7 +99,25 @@ async def _message(self, *, edit_original: bool = False) -> Message:
Button(emoji=LAST_EMOJI, row=self.row, disabled=is_final_page, callback=self._go_to_last_page),
)

return msg._replace(items=items + control_items, edit_original=edit_original or msg.edit_original)
return msg._replace(items=tuple(items) + control_items, edit_original=edit_original or msg.edit_original)

def _finalize_modal(self, callback: MaybeAwaitableFunc[P, Result]) -> Callable[P, Awaitable[Result]]:
async def finalize(*args: P.args, **kwargs: P.kwargs) -> Result:
result = await maybe_coroutine(callback, *args, **kwargs)
if (
(
result._type in (_ResultTypeEnum.MODEL, _ResultTypeEnum.FINISH)
) or (
result._type == _ResultTypeEnum.MESSAGE
and result._message is not None
and (not result._message.items or result._message.disable_items)
)
): # fmt: skip
# it is stop view and pagination is finished. so, we can stop all modals.
self.modal_controller.stop()
return result

return finalize

def _set_page_number(self, page_number: int) -> None:
if 0 <= page_number < self.max_page:
Expand All @@ -104,15 +132,11 @@ async def _go_to_previous_page(self, _: Interaction[Client]) -> Result:
return Result.send_message(message=await self._message(edit_original=True))

async def _go_to_page(self, interaction: Interaction[Client]) -> Result:
texts, interaction = await send_modal(
interaction,
ModalConfig(title='Page Number'),
(TextInput(label='page number', placeholder=f'1 ~ {self.max_page}', required=True),),
)
assert len(texts) >= 1
assert texts[0].isdigit()
self._set_page_number(int(texts[0]) - 1)
return Result.send_message(message=await self._message(edit_original=True), interaction=interaction)
result = await self.modal_controller.send_modal(interaction)
assert len(result.texts) >= 1
assert result.texts[0].isdigit()
self._set_page_number(int(result.texts[0]) - 1)
return Result.send_message(message=await self._message(edit_original=True), interaction=result.interaction)

async def _go_to_next_page(self, _: Interaction[Client]) -> Result:
self._set_page_number(self.current_page + 1)
Expand Down
30 changes: 24 additions & 6 deletions discord/ext/flow/view.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

from asyncio import CancelledError
from contextlib import suppress
from typing import TYPE_CHECKING

from discord import Client, Interaction, ui
Expand Down Expand Up @@ -31,7 +33,8 @@ def __init__(self, config: Button) -> None:
self.config = config

async def callback(self, interaction: Interaction[Client]) -> None:
await self.view.set_result(await maybe_coroutine(self.config.callback, interaction), interaction)
with suppress(CancelledError):
await self.view.set_result(await maybe_coroutine(self.config.callback, interaction), interaction)


class _Link(ui.Button['_View']):
Expand All @@ -55,7 +58,10 @@ def __init__(self, config: Select) -> None:
self.config = config

async def callback(self, interaction: Interaction[Client]) -> None:
await self.view.set_result(await maybe_coroutine(self.config.callback, interaction, self.values), interaction)
with suppress(CancelledError):
await self.view.set_result(
await maybe_coroutine(self.config.callback, interaction, self.values), interaction
)


class _UserSelect(ui.UserSelect['_View']):
Expand All @@ -74,7 +80,10 @@ def __init__(self, config: UserSelect) -> None:
self.config = config

async def callback(self, interaction: Interaction[Client]) -> None:
await self.view.set_result(await maybe_coroutine(self.config.callback, interaction, self.values), interaction)
with suppress(CancelledError):
await self.view.set_result(
await maybe_coroutine(self.config.callback, interaction, self.values), interaction
)


class _RoleSelect(ui.RoleSelect['_View']):
Expand All @@ -93,7 +102,10 @@ def __init__(self, config: RoleSelect) -> None:
self.config = config

async def callback(self, interaction: Interaction[Client]) -> None:
await self.view.set_result(await maybe_coroutine(self.config.callback, interaction, self.values), interaction)
with suppress(CancelledError):
await self.view.set_result(
await maybe_coroutine(self.config.callback, interaction, self.values), interaction
)


class _MentionableSelect(ui.MentionableSelect['_View']):
Expand All @@ -112,7 +124,10 @@ def __init__(self, config: MentionableSelect) -> None:
self.config = config

async def callback(self, interaction: Interaction[Client]) -> None:
await self.view.set_result(await maybe_coroutine(self.config.callback, interaction, self.values), interaction)
with suppress(CancelledError):
await self.view.set_result(
await maybe_coroutine(self.config.callback, interaction, self.values), interaction
)


class _ChannelSelect(ui.ChannelSelect['_View']):
Expand All @@ -131,7 +146,10 @@ def __init__(self, config: ChannelSelect) -> None:
self.config = config

async def callback(self, interaction: Interaction[Client]) -> None:
await self.view.set_result(await maybe_coroutine(self.config.callback, interaction, self.values), interaction)
with suppress(CancelledError):
await self.view.set_result(
await maybe_coroutine(self.config.callback, interaction, self.values), interaction
)


class _View(ui.View):
Expand Down
Loading

0 comments on commit 47b2baa

Please sign in to comment.