Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix optional embedded model #344

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 44 additions & 27 deletions odmantic/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,28 +699,10 @@ def __doc(
) -> Dict[str, Any]:
doc: Dict[str, Any] = {}
for field_name, field in model.__odm_fields__.items():
if include is not None and field_name not in include:
continue
if isinstance(field, ODMReference):
doc[field.key_name] = raw_doc[field_name][field.model.__primary_field__]
elif isinstance(field, ODMEmbedded):
doc[field.key_name] = self.__doc(raw_doc[field_name], field.model, None)
elif isinstance(field, ODMEmbeddedGeneric):
if field.generic_origin is dict:
doc[field.key_name] = {
item_key: self.__doc(item_value, field.model)
for item_key, item_value in raw_doc[field_name].items()
}
else:
doc[field.key_name] = [
self.__doc(item, field.model) for item in raw_doc[field_name]
]
elif field_name in model.__bson_serialized_fields__:
doc[field.key_name] = model.__fields__[field_name].type_.__bson__(
raw_doc[field_name]
if include is None or field_name in include:
doc[field.key_name] = self.__doc_value(
raw_doc[field_name], field_name, field, model
)
else:
doc[field.key_name] = raw_doc[field_name]

if model.Config.extra == "allow":
extras = set(raw_doc.keys()) - set(model.__odm_fields__.keys())
Expand All @@ -731,6 +713,34 @@ def __doc(
doc[extra] = bson_serialization_method(raw_doc[extra])
return doc

def __doc_value(
self,
raw_value: Any,
field_name: str,
field: ODMBaseField,
model: Type["_BaseODMModel"],
) -> Any:
if isinstance(field, ODMReference):
return raw_value[field.model.__primary_field__]
if isinstance(field, ODMEmbedded):
return self.__doc(raw_value, field.model)
if isinstance(field, ODMEmbeddedGeneric):
if field.generic_origin is dict:
return {
item_key: self.__doc(item_value, field.model)
for item_key, item_value in raw_value.items()
}
if field.generic_origin in (list, tuple, set):
return [self.__doc(item, field.model) for item in raw_value]
if field.generic_origin is Union: # actually Optional
if raw_value is not None:
return self.__doc(raw_value, field.model)
else:
return raw_value
if field_name in model.__bson_serialized_fields__:
return model.__fields__[field_name].type_.__bson__(raw_value)
return raw_value

def doc(self, include: Optional["AbstractSetIntStr"] = None) -> Dict[str, Any]:
"""Generate a document representation of the instance (as a dictionary).

Expand Down Expand Up @@ -817,13 +827,10 @@ def _parse_doc_to_obj( # noqa C901 # TODO: refactor document parsing
)
obj[field_name] = value
elif isinstance(field, ODMEmbeddedGeneric):
value = Undefined
raw_value = raw_doc.get(field.key_name, Undefined)
if raw_value is not Undefined:
if isinstance(raw_value, list) and (
field.generic_origin is list
or field.generic_origin is tuple
or field.generic_origin is set
if field.generic_origin in (list, tuple, set) and isinstance(
raw_value, list
):
value = []
for i, item in enumerate(raw_value):
Expand All @@ -835,7 +842,7 @@ def _parse_doc_to_obj( # noqa C901 # TODO: refactor document parsing
else:
value.append(item)
obj[field_name] = value
elif isinstance(raw_value, dict) and field.generic_origin is dict:
elif field.generic_origin is dict and isinstance(raw_value, dict):
value = {}
for item_key, item_value in raw_value.items():
sub_errors, item_value = field.model._parse_doc_to_obj(
Expand All @@ -847,6 +854,15 @@ def _parse_doc_to_obj( # noqa C901 # TODO: refactor document parsing
else:
value[item_key] = item_value
obj[field_name] = value
elif field.generic_origin is Union: # actually Optional
if raw_value is not None:
sub_errors, value = field.model._parse_doc_to_obj(
raw_value, base_loc=base_loc + (field_name,)
)
errors.extend(sub_errors)
obj[field_name] = value
else:
obj[field_name] = None
else:
errors.append(
ErrorWrapper(
Expand All @@ -855,6 +871,7 @@ def _parse_doc_to_obj( # noqa C901 # TODO: refactor document parsing
)
)
else:
value = Undefined
if not field.is_required_in_doc():
value = field.get_default_importing_value()
if value is Undefined:
Expand Down
40 changes: 39 additions & 1 deletion tests/integration/test_embedded_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Tuple
from typing import Any, Dict, List, Tuple, Union

import pytest

Expand Down Expand Up @@ -311,3 +311,41 @@ class Out(Model):
sync_engine.save(instance)
fetched = sync_engine.find_one(Out)
assert instance == fetched


@pytest.mark.parametrize("optional_missing", [False, True])
async def test_embedded_model_optional(aio_engine: AIOEngine, optional_missing: bool):
class E(EmbeddedModel):
f: int

class M(Model):
e: Union[E, None]

m = M() if optional_missing else M(e=E(f=3))
await aio_engine.save(m)
fetched = await aio_engine.find_one(M)
assert fetched == m

doc: Dict[str, Any] = {"_id": str(m.id)}
if not optional_missing:
doc["e"] = {"f": 3}
assert M.parse_doc(doc) == m


@pytest.mark.parametrize("optional_missing", [False, True])
def test_sync_embedded_model_optional(sync_engine: SyncEngine, optional_missing: bool):
class E(EmbeddedModel):
f: int

class M(Model):
e: Union[E, None]

m = M() if optional_missing else M(e=E(f=3))
sync_engine.save(m)
fetched = sync_engine.find_one(M)
assert fetched == m

doc: Dict[str, Any] = {"_id": str(m.id)}
if not optional_missing:
doc["e"] = {"f": 3}
assert M.parse_doc(doc) == m