Skip to content

Commit

Permalink
Update internal dict keys when primary keys are modified mid-transaction
Browse files Browse the repository at this point in the history
Both operations and unit of work internal dicts use PKs to point to
stored objects. If the PKs change mid-transaction, keys are now
updated to the new correct values.
  • Loading branch information
dtheodor committed Nov 10, 2014
1 parent 82b60f9 commit 1ef3c31
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 8 deletions.
21 changes: 19 additions & 2 deletions sqlalchemy_continuum/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@

import six
import sqlalchemy as sa
from sqlalchemy_utils import identity
from sqlalchemy_utils import identity, get_primary_keys, has_changes

from .utils import commited_identity

class Operation(object):
INSERT = 0
Expand Down Expand Up @@ -98,6 +99,7 @@ def add_update(self, target):
del state_copy[rel_key]

if state_copy:
self._sanitize_keys(target)
key = self.format_key(target)
# if the object has already been added with an INSERT,
# then this is a modification within the same transaction and
Expand All @@ -111,4 +113,19 @@ def add_update(self, target):
self.add(Operation(target, operation))

def add_delete(self, target):
self.add(Operation(target, Operation.DELETE))
self.add(Operation(target, Operation.DELETE))

def _sanitize_keys(self, target):
"""The operations key for target may not be valid if this target is in
`self.objects` but its primary key has been modified. Check against that
and update the key.
"""
key = self.format_key(target)
mapper = sa.inspect(target).mapper
for pk in mapper.primary_key:
if has_changes(target, mapper.get_property_by_column(pk).key):
old_key = target.__class__, commited_identity(target)
if old_key in self.objects:
# replace old key with the new one
self.objects[key] = self.objects.pop(old_key)
break
30 changes: 24 additions & 6 deletions sqlalchemy_continuum/unit_of_work.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from copy import copy

import sqlalchemy as sa
from sqlalchemy_utils import get_primary_keys, identity
from sqlalchemy_utils import get_primary_keys, identity, has_changes
from .operation import Operations
from .utils import (
end_tx_column_name,
version_class,
is_session_modified,
tx_column_name,
versioned_column_properties
versioned_column_properties,
commited_identity
)


Expand Down Expand Up @@ -123,19 +124,36 @@ def create_transaction(self, session):
session.add(self.current_transaction)
return self.current_transaction

def _sanitize_obj_key(self, target):
"""
The key for target in `self.version_objs` may not be valid if its
primary key has been modified. Check against that and update the key.
"""
key = self._create_key(target, identity(target))
mapper = sa.inspect(target).mapper
for pk in mapper.primary_key:
if has_changes(target, mapper.get_property_by_column(pk).key):
old_key = self._create_key(target, commited_identity(target))
if old_key in self.version_objs:
# replace old key with the new one
self.version_objs[key] = self.version_objs.pop(old_key)
break
return key

def _create_key(self, target, pks):
return version_class(target.__class__), (pks, self.current_transaction.id)

def get_or_create_version_object(self, target):
"""
Return version object for given parent object. If no version object
exists for given parent object, create one.
:param target: Parent object to create the version object for
"""
version_cls = version_class(target.__class__)
version_id = identity(target) + (self.current_transaction.id, )
version_key = (version_cls, version_id)
version_key = self._sanitize_obj_key(target)

if version_key not in self.version_objs:
version_obj = version_cls()
version_obj = version_class(target.__class__)()
self.version_objs[version_key] = version_obj
self.version_session.add(version_obj)
tx_column = self.manager.option(
Expand Down
17 changes: 17 additions & 0 deletions sqlalchemy_continuum/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,3 +414,20 @@ def changeset(obj):
if new_value:
data[prop.key] = [new_value, old_value]
return data


def commited_identity(obj):
"""Returns a tuple of the primary keys of the object without any
modifications that may have occured within the session
"""
old_pks = []
obj_inspect = sa.inspect(obj)
mapper = obj_inspect.mapper
for column in get_primary_keys(obj).itervalues():
old_pk = obj_inspect.attrs.get(
mapper.get_property_by_column(column).key).history.deleted
if old_pk:
old_pks.append(old_pk[0])
else:
old_pks.append(getattr(obj, column.name))
return tuple(old_pks)

0 comments on commit 1ef3c31

Please sign in to comment.