Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start using anyio for structured concurrency #1065

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ For more installation options (eg: `aws`, `gcp`, `srv` ...) you can look in the
## Example

```python
import asyncio
import anyio
from typing import Optional

from motor.motor_asyncio import AsyncIOMotorClient
Expand Down Expand Up @@ -94,7 +94,7 @@ async def example():


if __name__ == "__main__":
asyncio.run(example())
anyio.run(example)
```

## Links
Expand Down
4 changes: 2 additions & 2 deletions beanie/executors/migrate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import asyncio
import logging
import os
import shutil
from datetime import datetime
from pathlib import Path
from typing import Any

import anyio
import click
import toml

Expand Down Expand Up @@ -197,7 +197,7 @@ def migrate(
settings_kwargs["use_transaction"] = use_transaction
settings = MigrationSettings(**settings_kwargs)

asyncio.run(run_migrate(settings))
anyio.run(run_migrate, settings)


@migrations.command()
Expand Down
79 changes: 42 additions & 37 deletions beanie/migrations/controllers/iterative.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asyncio
from functools import partial
from inspect import isclass, signature
from typing import Any, List, Optional, Type, Union

from anyio import create_task_group

from beanie.migrations.controllers.base import BaseMigrationController
from beanie.migrations.utils import update_dict
from beanie.odm.documents import Document
Expand Down Expand Up @@ -92,43 +94,46 @@ def models(self) -> List[Type[Document]]:

async def run(self, session):
output_documents = []
all_migration_ops = []
async for input_document in self.input_document_model.find_all(
session=session
):
output = DummyOutput()
function_kwargs = {
"input_document": input_document,
"output_document": output,
}
if "self" in self.function_signature.parameters:
function_kwargs["self"] = None
await self.function(**function_kwargs)
output_dict = (
input_document.dict()
if not IS_PYDANTIC_V2
else input_document.model_dump()
)
update_dict(output_dict, output.dict())
output_document = parse_model(
self.output_document_model, output_dict
)
output_documents.append(output_document)

if len(output_documents) == self.batch_size:
all_migration_ops.append(
self.output_document_model.replace_many(
documents=output_documents, session=session
)
async with create_task_group() as tg:
async for input_document in self.input_document_model.find_all(
session=session
):
output = DummyOutput()
function_kwargs = {
"input_document": input_document,
"output_document": output,
}
if "self" in self.function_signature.parameters:
function_kwargs["self"] = None
await self.function(**function_kwargs)
output_dict = (
input_document.dict()
if not IS_PYDANTIC_V2
else input_document.model_dump()
)
output_documents = []

if output_documents:
all_migration_ops.append(
self.output_document_model.replace_many(
documents=output_documents, session=session
update_dict(output_dict, output.dict())
output_document = parse_model(
self.output_document_model, output_dict
)
output_documents.append(output_document)

if len(output_documents) == self.batch_size:
tg.start_soon(
partial(
self.output_document_model.replace_many,
documents=output_documents,
session=session,
)
)
output_documents = []

if output_documents:
tg.start_soon(
partial(
self.output_document_model.replace_many,
documents=output_documents,
session=session,
)
)
)
await asyncio.gather(*all_migration_ops)

return IterativeMigration
36 changes: 20 additions & 16 deletions beanie/odm/actions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import inspect
from enum import Enum
from functools import wraps
from functools import partial, wraps
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -15,6 +14,7 @@
Union,
)

from anyio import create_task_group, to_thread
from typing_extensions import ParamSpec

if TYPE_CHECKING:
Expand Down Expand Up @@ -76,7 +76,7 @@ def add_action(
:param document_class: document class
:param event_types: List[EventTypes]
:param action_direction: ActionDirections - before or after
:param funct: Callable - function
:param funct: Callable - function must be either thread safe or async safe
"""
if cls._actions.get(document_class) is None:
cls._actions[document_class] = {
Expand Down Expand Up @@ -130,19 +130,23 @@ async def run_actions(
actions_list = cls.get_action_list(
document_class, event_type, action_direction
)
coros = []
for action in actions_list:
if action.__name__ in exclude:
continue

if inspect.iscoroutinefunction(action):
coros.append(action(instance))
elif inspect.isfunction(action):
action(instance)
await asyncio.gather(*coros)


# `Any` because there is arbitrary attribute assignment on this type
async with create_task_group() as tg:
for action in actions_list:
if action.__name__ in exclude:
continue
if inspect.iscoroutinefunction(action):
tg.start_soon(action, instance)
elif inspect.isfunction(action):
tg.start_soon(
partial(
to_thread.run_sync,
partial(action, instance),
abandon_on_cancel=True,
)
)


# `Any` because there is an arbitrary attribute assignment on this type
F = TypeVar("F", bound=Any)


Expand Down
88 changes: 45 additions & 43 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import warnings
from datetime import datetime, timezone
from enum import Enum
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -20,6 +20,7 @@
)
from uuid import UUID, uuid4

from anyio import create_task_group
from bson import DBRef, ObjectId
from lazy_model import LazyModel
from motor.motor_asyncio import AsyncIOMotorClientSession
Expand Down Expand Up @@ -127,6 +128,10 @@
DocumentProjectionType = TypeVar("DocumentProjectionType", bound=BaseModel)


def only_documents(objets: List[Any]):
    """Return a lazy iterator over the ``Document`` instances in ``objets``.

    Items that are not ``Document`` instances are skipped; the remaining
    items come out in their original order.
    """

    def _is_document(candidate: Any) -> bool:
        # Named predicate instead of an inline lambda; same isinstance check.
        return isinstance(candidate, Document)

    return filter(_is_document, objets)


def json_schema_extra(schema: Dict[str, Any], model: Type["Document"]) -> None:
# remove excluded fields from the json schema
properties = schema.get("properties")
Expand Down Expand Up @@ -353,16 +358,16 @@ async def insert(
LinkTypes.OPTIONAL_LIST,
]:
if isinstance(value, List):
await asyncio.gather(
*[
obj.save(
link_rule=WriteRules.WRITE,
session=session,
async with create_task_group() as tg:
for obj in only_documents(value):
tg.start_soon(
partial(
obj.save,
link_rule=WriteRules.WRITE,
session=session,
)
)
for obj in value
if isinstance(obj, Document)
]
)

result = await self.get_motor_collection().insert_one(
get_dict(
self, to_db=True, keep_nulls=self.get_settings().keep_nulls
Expand Down Expand Up @@ -513,18 +518,17 @@ async def replace(
LinkTypes.OPTIONAL_BACK_LIST,
]:
if isinstance(value, List):
await asyncio.gather(
*[
obj.replace(
link_rule=link_rule,
bulk_writer=bulk_writer,
ignore_revision=ignore_revision,
session=session,
async with create_task_group() as tg:
for obj in only_documents(value):
tg.start_soon(
partial(
obj.replace,
link_rule=link_rule,
bulk_writer=bulk_writer,
ignore_revision=ignore_revision,
session=session,
)
)
for obj in value
if isinstance(obj, Document)
]
)

use_revision_id = self.get_settings().use_revision
find_query: Dict[str, Any] = {"_id": self.id}
Expand Down Expand Up @@ -586,15 +590,15 @@ async def save(
LinkTypes.OPTIONAL_BACK_LIST,
]:
if isinstance(value, List):
await asyncio.gather(
*[
obj.save(
link_rule=link_rule, session=session
async with create_task_group() as tg:
for obj in only_documents(value):
tg.start_soon(
partial(
obj.save,
link_rule=link_rule,
session=session,
)
)
for obj in value
if isinstance(obj, Document)
]
)

if self.get_settings().keep_nulls is False:
return await self.update(
Expand Down Expand Up @@ -911,16 +915,15 @@ async def delete(
LinkTypes.OPTIONAL_BACK_LIST,
]:
if isinstance(value, List):
await asyncio.gather(
*[
obj.delete(
link_rule=DeleteRules.DELETE_LINKS,
**pymongo_kwargs,
async with create_task_group() as tg:
for obj in only_documents(value):
tg.start_soon(
partial(
obj.delete,
link_rule=DeleteRules.DELETE_LINKS,
**pymongo_kwargs,
)
)
for obj in value
if isinstance(obj, Document)
]
)

return await self.find_one({"_id": self.id}).delete(
session=session, bulk_writer=bulk_writer, **pymongo_kwargs
Expand Down Expand Up @@ -1182,12 +1185,11 @@ async def fetch_link(self, field: Union[str, Any]):
setattr(self, field, values)

async def fetch_all_links(self):
coros = []
link_fields = self.get_link_fields()
if link_fields is not None:
for ref in link_fields.values():
coros.append(self.fetch_link(ref.field_name)) # TODO lists
await asyncio.gather(*coros)
if link_fields is not None and len(link_fields.values()) > 0:
async with create_task_group() as tg:
for ref in link_fields.values():
tg.start_soon(self.fetch_link, ref.field_name)

@classmethod
def get_link_fields(cls) -> Optional[Dict[str, LinkInfo]]:
Expand Down
10 changes: 5 additions & 5 deletions beanie/odm/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import asyncio
from collections import OrderedDict
from dataclasses import dataclass
from enum import Enum
Expand All @@ -18,6 +17,7 @@
)
from typing import OrderedDict as OrderedDictType

from anyio import create_task_group
from bson import DBRef, ObjectId
from bson.errors import InvalidId
from pydantic import BaseModel
Expand Down Expand Up @@ -363,10 +363,10 @@ def repack_links(

@classmethod
async def fetch_many(cls, links: List[Link]):
coros = []
for link in links:
coros.append(link.fetch())
return await asyncio.gather(*coros)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method previously returned the gathered results here. I don't see how the results are being returned now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I will add a test for this; it is a breaking change.

if links:
async with create_task_group() as tg:
for link in links:
tg.start_soon(link.fetch)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I'm not mistaken (I don't know anyio well), we need something other than tg.start_soon() here, since it does not return anything that can be awaited for a result...
Alternatively, if the results can somehow be collected as the tasks finish, we could wait for all the tasks to complete and then return the collected results.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E.g. FastAPI does something like this, lines 856-867:
https://github.com/fastapi/fastapi/blob/bffb4115a9b63127948cc5e1aa14d73940734f75/fastapi/dependencies/utils.py#L856
It calls tg.start_soon(), but the function it schedules stores its result in a variable declared in the enclosing ("upper") scope. I'm not sure whether this is feasible here, since we call an "internal" Link class method. It may be feasible, but it would require some more refactoring...


if IS_PYDANTIC_V2:

Expand Down
Loading
Loading