Skip to content

Commit

Permalink
fix: models get_or_create keyerror (#1584)
Browse files Browse the repository at this point in the history
* add tests

* add changelog

* add changelog

* fix

* throw

* fix test

* Remove unnecessary import and update method doc

* More test cases for update_or_create method

* refactor: only query once before create for update_or_create method

* Update changelog and fix mssql ci error

* fix codacy issues

---------

Co-authored-by: Waket Zheng <[email protected]>
  • Loading branch information
jiangying000 and waketzheng authored Jun 16, 2024
1 parent 3c36151 commit 6b01815
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 11 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ Changelog
0.21
====

0.21.4
------
Fixed
^^^^^
- Fix `update_or_create` errors when field value changed. (#1584)

0.21.3
------
Fixed
Expand Down
34 changes: 32 additions & 2 deletions tests/test_model_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,9 @@ async def test_implicit_clone_pk_required_none(self):
class TestModelMethods(test.TestCase):
async def asyncSetUp(self):
await super().asyncSetUp()
self.mdl = await Tournament.create(name="Test")
self.mdl2 = Tournament(name="Test")
self.cls = Tournament
self.mdl = await self.cls.create(name="Test")
self.mdl2 = self.cls(name="Test")

async def test_save(self):
oldid = self.mdl.id
Expand Down Expand Up @@ -176,6 +176,36 @@ async def test_update_or_create(self):
mdl2 = await self.cls.get(name="Test2")
self.assertEqual(mdl, mdl2)

async def test_update_or_create_with_defaults(self):
mdl = await self.cls.get(name=self.mdl.name)
mdl_dict = dict(mdl)
oldid = mdl.id
mdl.id = 135
with self.assertRaisesRegex(ParamsError, "Conflict value with key='id':"):
# Missing query: check conflict with kwargs and defaults before create
await self.cls.update_or_create(id=mdl.id, defaults=mdl_dict)
desc = str(uuid4())
# If there is no conflict with defaults and kwargs, it will be success to update or create
defaults = dict(mdl_dict, desc=desc)
kwargs = {"id": defaults["id"], "name": defaults["name"]}
mdl, created = await self.cls.update_or_create(defaults, **kwargs)
self.assertFalse(created)
self.assertEqual(defaults["desc"], mdl.desc)
self.assertNotEqual(self.mdl.desc, mdl.desc)
# Hint query: use defauts to update without checking conflict
mdl2, created = await self.cls.update_or_create(
id=oldid, desc=desc, defaults=dict(mdl_dict, desc="new desc")
)
self.assertFalse(created)
self.assertNotEqual(dict(mdl), dict(mdl2))
# Missing query: success to create if no conflict
not_exist_name = str(uuid4())
no_conflict_defaults = {"name": not_exist_name, "desc": desc}
no_conflict_kwargs = {"name": not_exist_name}
mdl, created = await self.cls.update_or_create(no_conflict_defaults, **no_conflict_kwargs)
self.assertTrue(created)
self.assertEqual(not_exist_name, mdl.name)

async def test_first(self):
mdl = await self.cls.first()
self.assertEqual(self.mdl.id, mdl.id)
Expand Down
31 changes: 22 additions & 9 deletions tortoise/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,22 +1026,35 @@ async def get_or_create(
:param using_db: Specific DB connection to use instead of default bound
:param kwargs: Query parameters.
:raises IntegrityError: If create failed
:raises TransactionManagementError: If transaction error
:raises ParamsError: If defaults conflict with kwargs
"""
if not defaults:
defaults = {}
db = using_db or cls._choose_db(True)
try:
return await cls.filter(**kwargs).using_db(db).get(), False
except DoesNotExist:
return await cls._create_or_get(db, defaults, **kwargs)

@classmethod
async def _create_or_get(
cls, db: BaseDBAsyncClient, defaults: dict, **kwargs
) -> Tuple[Self, bool]:
"""Try to create, if fails with IntegrityError then try to get"""
for key in defaults.keys() & kwargs.keys():
if (default_value := defaults[key]) != (query_value := kwargs[key]):
raise ParamsError(f"Conflict value with {key=}: {default_value=} vs {query_value=}")
merged_defaults = {**kwargs, **defaults}
try:
async with in_transaction(connection_name=db.connection_name) as connection:
return await cls.create(using_db=connection, **merged_defaults), True
except IntegrityError as exc:
try:
async with in_transaction(connection_name=db.connection_name) as connection:
return await cls.create(using_db=connection, **defaults, **kwargs), True
except IntegrityError as exc:
try:
return await cls.filter(**kwargs).using_db(db).get(), False
except DoesNotExist:
pass
raise exc
return await cls.filter(**kwargs).using_db(db).get(), False
except DoesNotExist:
pass
raise exc

@classmethod
def select_for_update(
Expand Down Expand Up @@ -1084,7 +1097,7 @@ async def update_or_create(
if instance:
await instance.update_from_dict(defaults).save(using_db=connection)
return instance, False
return await cls.get_or_create(defaults, db, **kwargs)
return await cls._create_or_get(db, defaults, **kwargs)

@classmethod
async def create(
Expand Down

0 comments on commit 6b01815

Please sign in to comment.