diff --git a/.gitignore b/.gitignore index a015a2d6..9b3348a9 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,6 @@ nosetests.xml .mr.developer.cfg .project .pydevproject + +#Pycharm +.idea \ No newline at end of file diff --git a/sqlalchemy_continuum/dialects/postgresql.py b/sqlalchemy_continuum/dialects/postgresql.py index f24d9077..f6a0ed01 100644 --- a/sqlalchemy_continuum/dialects/postgresql.py +++ b/sqlalchemy_continuum/dialects/postgresql.py @@ -29,6 +29,63 @@ WHERE NOT EXISTS (SELECT 1 FROM upsert); """ +update_upsert_cte_sql = """ +WITH upsert as +( + UPDATE {version_table_name} + SET {update_values} + WHERE + {transaction_column} = transaction_id_value + AND + ( + ({old_primary_key_criteria}) + OR + ({new_primary_key_criteria}) + ) + RETURNING * +) +INSERT INTO {version_table_name} +({transaction_column}, {operation_type_column}, {column_names}) +SELECT + transaction_id_value, + {operation_type}, + {insert_values} +WHERE NOT EXISTS (SELECT 1 FROM upsert); +""" + +delete_upsert_cte_sql = """ +WITH delete_stale as ( + DELETE FROM {version_table_name} + WHERE + {transaction_column} = transaction_id_value + AND + {primary_key_criteria} + AND + "operation_type" = 0 + RETURNING * +), upsert as +( + UPDATE {version_table_name} + SET {update_values} + WHERE + {transaction_column} = transaction_id_value + AND + {primary_key_criteria} + RETURNING * +) +INSERT INTO {version_table_name} +({transaction_column}, {operation_type_column}, {column_names}) +SELECT + transaction_id_value, + {operation_type}, + {insert_values} +WHERE + NOT EXISTS (SELECT 1 from delete_stale) + AND + NOT EXISTS (SELECT 1 FROM upsert); +""" + + temporary_transaction_sql = """ CREATE TEMP TABLE IF NOT EXISTS {temporary_transaction_table} ({transaction_table_columns}) @@ -205,6 +262,8 @@ class UpsertSQL(SQLConstruct): 'primary_key_criteria': ' AND ', } + upsert_cte_sql = upsert_cte_sql + def __init__(self, *args, **kwargs): SQLConstruct.__init__(self, *args, **kwargs) @@ -267,12 +326,29 @@ 
def __str__(self): for key, join_operator in self.builders.items(): params[key] = join_operator.join(getattr(self, key)) - sql = upsert_cte_sql.format(**params) + sql = self.upsert_cte_sql.format(**params) return sql class DeleteUpsertSQL(UpsertSQL): + """ + AFTER DELETE on parent_table: + exists (OLD.[pks], current_tx_id) in version_table? + No: + INSERT with operation_type = 2 + Yes: + we have one of the following scenarios: + if existing operation_type = 0, DELETE version entry + (means that a new record has been created but now is being deleted + in the same transaction. No version should be created) + if existing operation_type = 1, UPDATE with operation_type = 2 + (an object has been updated but is now being deleted) + if existing operation_type == 2, UPDATE with operation_type = 2 + (not sure if this can happen, however a second DELETE of the same + PKs still results in a DELETE) + """ operation_type = 2 + upsert_cte_sql = delete_upsert_cte_sql def build_primary_key_criteria(self): return [ @@ -284,7 +360,7 @@ def build_mod_tracking_values(self): return ['True'] * len(self.columns_without_pks) def build_update_values(self): - return [ + return ['%s = 2' % self.operation_type_column_name] + [ '"{name}" = OLD."{name}"'.format(name=c.name) for c in self.columns ] @@ -294,6 +370,16 @@ def build_values(self): class InsertUpsertSQL(UpsertSQL): + """ + AFTER INSERT on parent_table: + exists (NEW.[pks], current_tx_id) in version_table? + No: + INSERT with operation_type = 0 + Yes: + (means target was deleted and re-inserted in the same transaction, + so it's actually an update) + UPDATE with operation_type = 1 + """ operation_type = 0 def build_mod_tracking_values(self): @@ -301,7 +387,37 @@ def build_mod_tracking_values(self): class UpdateUpsertSQL(UpsertSQL): + """ + AFTER UPDATE on parent_table: + exists (OLD.[pks] OR NEW.[pks], current_tx_id) in version_table? + (Normally we expect to find the OLD.[pks] in the version table. 
+ However, the NEW.[pks] can already exist in the version table in the + following edge case: + the record with the OLD pks was deleted and a different record was updated + to hijack the OLD pks. In this case the value of OLD.[pks] is irrelevant) + No: + INSERT with operation_type = 1 + Yes: + we have one of the following scenarios: + if existing operation_type = 0 UPDATE with operation_type = 0 + (means that the record is new in this transaction but is being + updated. Its version should remain with operation type INSERT) + if existing operation_type = 1 UPDATE with operation_type = 1 + (a 2nd update to the same object in the same transaction) + if existing operation_type = 2 UPDATE with operation_type = 1 + (this is the case described in the justification of why we check for + NEW.[pks]. This is also an UPDATE) + """ operation_type = 1 + upsert_cte_sql = update_upsert_cte_sql + + @property + def builders(self): + builders = super(UpdateUpsertSQL, self).builders.copy() + del builders['primary_key_criteria'] + builders.update(old_primary_key_criteria=' AND ', + new_primary_key_criteria=' AND ') + return builders def build_mod_tracking_values(self): return [ @@ -309,6 +425,45 @@ def build_mod_tracking_values(self): .format(c.name) for c in self.columns_without_pks ] + def build_new_primary_key_criteria(self): + return [ + '"{name}" = NEW."{name}"'.format(name=c.name) + for c in self.columns if c.primary_key + ] + + def build_old_primary_key_criteria(self): + return [ + '"{name}" = OLD."{name}"'.format(name=c.name) + for c in self.columns if c.primary_key + ] + + def build_update_values(self): + parent_columns = [ + '"{name}" = NEW."{name}"'.format(name=c.name) + for c in self.columns + ] + mod_columns = [] + if self.use_property_mod_tracking: + mod_columns = [ + '{0}_mod = {0}_mod OR OLD."{0}" IS DISTINCT FROM NEW."{0}"' + .format(c.name) + for c in self.columns_without_pks + ] + + operation_type_update = """{operation_type} = ( + CASE + WHEN {operation_type} = 2 
THEN 1 + WHEN {operation_type} = 0 THEN 0 + ELSE 1 + END +)""".format(operation_type=self.operation_type_column_name) + + return ( + [operation_type_update] + + parent_columns + + mod_columns + ) + class ValiditySQL(SQLConstruct): @property diff --git a/sqlalchemy_continuum/operation.py b/sqlalchemy_continuum/operation.py index 72c72b2d..05728ffa 100644 --- a/sqlalchemy_continuum/operation.py +++ b/sqlalchemy_continuum/operation.py @@ -6,13 +6,13 @@ import six import sqlalchemy as sa -from sqlalchemy_utils import identity - +from sqlalchemy_utils import identity, get_primary_keys, has_changes class Operation(object): INSERT = 0 UPDATE = 1 DELETE = 2 + STALE_VERSION = -1 def __init__(self, target, type): self.target = target @@ -98,7 +98,40 @@ def add_update(self, target): del state_copy[rel_key] if state_copy: - self.add(Operation(target, Operation.UPDATE)) + self._sanitize_keys(target) + key = self.format_key(target) + # if the object has already been added with an INSERT, + # then this is a modification within the same transaction and + # this is still an INSERT + if (target in self and + self[key].type == Operation.INSERT): + operation = Operation.INSERT + else: + operation = Operation.UPDATE + + self.add(Operation(target, operation)) def add_delete(self, target): - self.add(Operation(target, Operation.DELETE)) + if target in self and \ + self[self.format_key(target)].type == Operation.INSERT: + # if the target's existing operation is INSERT, it is being + # deleted within the same transaction and no version entry + # should be persisted + self.add(Operation(target, Operation.STALE_VERSION)) + else: + self.add(Operation(target, Operation.DELETE)) + + def _sanitize_keys(self, target): + """The operations key for target may not be valid if this target is in + `self.objects` but its primary key has been modified. Check against that + and update the key. 
+ """ + key = self.format_key(target) + mapper = sa.inspect(target).mapper + for pk in mapper.primary_key: + if has_changes(target, mapper.get_property_by_column(pk).key): + old_key = target.__class__, sa.inspect(target).identity + if old_key in self.objects: + # replace old key with the new one + self.objects[key] = self.objects.pop(old_key) + break \ No newline at end of file diff --git a/sqlalchemy_continuum/plugins/transaction_changes.py b/sqlalchemy_continuum/plugins/transaction_changes.py index cece8d5b..f665ca9a 100644 --- a/sqlalchemy_continuum/plugins/transaction_changes.py +++ b/sqlalchemy_continuum/plugins/transaction_changes.py @@ -107,6 +107,7 @@ def after_version_class_built(self, parent_cls, version_cls): primaryjoin=( self.model_class.transaction_id == transaction_column ), - foreign_keys=[self.model_class.transaction_id] + foreign_keys=[self.model_class.transaction_id], + passive_deletes='all' ) parent_cls.__versioned__['transaction_changes'] = self.model_class diff --git a/sqlalchemy_continuum/unit_of_work.py b/sqlalchemy_continuum/unit_of_work.py index 49d89f48..0bac7789 100644 --- a/sqlalchemy_continuum/unit_of_work.py +++ b/sqlalchemy_continuum/unit_of_work.py @@ -1,8 +1,8 @@ from copy import copy import sqlalchemy as sa -from sqlalchemy_utils import get_primary_keys, identity -from .operation import Operations +from sqlalchemy_utils import get_primary_keys, identity, has_changes +from .operation import Operations, Operation from .utils import ( end_tx_column_name, version_class, @@ -123,6 +123,25 @@ def create_transaction(self, session): session.add(self.current_transaction) return self.current_transaction + def _sanitize_obj_key(self, target): + """ + The key for target in `self.version_objs` may not be valid if its + primary key has been modified. Check against that and update the key. 
+ """ + key = self._create_key(target, identity(target)) + mapper = sa.inspect(target).mapper + for pk in mapper.primary_key: + if has_changes(target, mapper.get_property_by_column(pk).key): + old_key = self._create_key(target, sa.inspect(target).identity) + if old_key in self.version_objs: + # replace old key with the new one + self.version_objs[key] = self.version_objs.pop(old_key) + break + return key + + def _create_key(self, target, pks): + return version_class(target.__class__), (pks, self.current_transaction.id) + def get_or_create_version_object(self, target): """ Return version object for given parent object. If no version object @@ -130,12 +149,10 @@ def get_or_create_version_object(self, target): :param target: Parent object to create the version object for """ - version_cls = version_class(target.__class__) - version_id = identity(target) + (self.current_transaction.id, ) - version_key = (version_cls, version_id) + version_key = self._sanitize_obj_key(target) if version_key not in self.version_objs: - version_obj = version_cls() + version_obj = version_class(target.__class__)() self.version_objs[version_key] = version_obj self.version_session.add(version_obj) tx_column = self.manager.option( @@ -151,6 +168,20 @@ def get_or_create_version_object(self, target): else: return self.version_objs[version_key] + def delete_version_object(self, target): + """ + Delete version object for `target` parent object, if a version object + exists. + + :param target: Parent object for which the version object should be + removed + """ + version_key = self._sanitize_obj_key(target) + version_obj = self.version_objs.pop(version_key, None) + if version_obj is not None: + self.version_session.delete(version_obj) + + def process_operation(self, operation): """ Process given operation object. 
The operation processing has x stages: @@ -164,18 +195,21 @@ def process_operation(self, operation): :param operation: Operation object """ target = operation.target - version_obj = self.get_or_create_version_object(target) - version_obj.operation_type = operation.type - self.assign_attributes(target, version_obj) + if operation.type == Operation.STALE_VERSION: + self.delete_version_object(target) + else: + version_obj = self.get_or_create_version_object(target) + version_obj.operation_type = operation.type + self.assign_attributes(target, version_obj) - self.manager.plugins.after_create_version_object( - self, target, version_obj - ) - if self.manager.option(target, 'strategy') == 'validity': - self.update_version_validity( - target, - version_obj + self.manager.plugins.after_create_version_object( + self, target, version_obj ) + if self.manager.option(target, 'strategy') == 'validity': + self.update_version_validity( + target, + version_obj + ) operation.processed = True def create_version_objects(self, session): diff --git a/tests/test_delete.py b/tests/test_delete.py index 3934d2d7..ae5f3c26 100644 --- a/tests/test_delete.py +++ b/tests/test_delete.py @@ -1,4 +1,6 @@ import sqlalchemy as sa +from sqlalchemy_continuum import Operation, version_class + from tests import TestCase @@ -25,6 +27,57 @@ def test_creates_versions_on_delete(self): assert versions[1].name == u'Some article' assert versions[1].content == u'Some content' + def test_insert_delete_in_single_transaction(self): + """Test that when an object is created and then deleted within the + same transaction, no history entry is created. 
+ """ + article = self.Article(name=u'Article name') + self.session.add(article) + self.session.flush() + + self.session.delete(article) + self.session.commit() + + ArticleVersion = version_class(self.Article) + assert self.session.query(ArticleVersion).count() == 0 + + def test_update_delete_in_single_transaction(self): + """Test that when an object is updated and then deleted within the + same transaction, the operation type DELETE is stored. + """ + article = self.Article(name=u'Article name') + self.session.add(article) + self.session.commit() + + article.name = u'Updated name' + self.session.flush() + self.session.delete(article) + self.session.commit() + + ArticleVersion = version_class(self.Article) + versions_query = self.session.query(self.ArticleVersion) + assert versions_query.count() == 2 + assert versions_query[1].operation_type == Operation.DELETE + + def test_modify_primary_key(self): + """Test that modifying the primary key within the same transaction + maintains correct delete behavior""" + article = self.Article(name=u'Article name') + self.session.add(article) + self.session.commit() + + article.name = u'Second name' + self.session.flush() + article.id += 1 + self.session.delete(article) + self.session.commit() + + ArticleVersion = version_class(self.Article) + versions_q = self.session.query(ArticleVersion)\ + .order_by(ArticleVersion.transaction_id) + assert versions_q.count() == 2 + assert versions_q[1].operation_type == Operation.DELETE + class TestDeleteWithDeferredColumn(TestCase): def create_models(self): diff --git a/tests/test_exotic_operation_combos.py b/tests/test_exotic_operation_combos.py index 8a24f947..3f25a816 100644 --- a/tests/test_exotic_operation_combos.py +++ b/tests/test_exotic_operation_combos.py @@ -40,6 +40,9 @@ def test_insert_deleted_and_flushed_object(self): assert article2.versions[1].operation_type == 1 def test_replace_deleted_object_with_update(self): + """Test that deleting an object and hijacking its primary key 
results + in turning the operation_type = 2 to an operation_type = 1 + """ article = self.Article() article.name = u'Some article' article.content = u'Some content' diff --git a/tests/test_insert.py b/tests/test_insert.py index 12b0bf04..b12fdd3d 100644 --- a/tests/test_insert.py +++ b/tests/test_insert.py @@ -1,5 +1,6 @@ import sqlalchemy as sa -from sqlalchemy_continuum import count_versions, versioning_manager +from sqlalchemy_continuum import count_versions, versioning_manager, \ + Operation, version_class from tests import TestCase @@ -39,6 +40,36 @@ def test_multiple_consecutive_flushes(self): assert article.versions.count() == 1 assert article2.versions.count() == 1 + def test_multiple_flushes_store_operation_type(self): + """Test that after multiple flushes that affect a newly created object, + the insert operation type is committed + """ + article = self.Article(name=u'Article name') + self.session.add(article) + self.session.flush() + article.name = u'Changed my mind' + self.session.commit() + assert article.versions.count() == 1 + assert article.versions[0].operation_type == Operation.INSERT + + def test_modify_primary_key(self): + """Test that modifying the primary key within the insert transaction + maintains correct insert behavior""" + article = self.Article(name=u'Article name') + self.session.add(article) + self.session.flush() + article.id += 1 + self.session.commit() + assert article.versions.count() == 1 + assert article.versions[-1].operation_type == Operation.INSERT + + # also check that no additional article versions have leaked... 
+ ArticleVersion = version_class(self.Article) + versions_query = self.session.query(ArticleVersion)\ + .order_by(ArticleVersion.transaction_id) + assert versions_query.count() == 1 + assert versions_query[0].operation_type == Operation.INSERT + class TestInsertWithDeferredColumn(TestCase): def create_models(self): diff --git a/tests/test_update.py b/tests/test_update.py index 29e0e53c..e8bb2b99 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -1,4 +1,6 @@ import sqlalchemy as sa +from sqlalchemy_continuum import Operation, version_class + from tests import TestCase @@ -76,6 +78,26 @@ def test_multiple_updates_within_same_transaction(self): assert version.name == u'Some article' assert version.content == u'Updated content 2' + def test_modify_primary_key(self): + """Test that modifying the primary key within the same transaction + maintains correct update behavior""" + article = self.Article(name=u'Article name') + self.session.add(article) + self.session.commit() + + article.name = u'Second name' + self.session.flush() + article.id += 1 + self.session.commit() + + assert article.versions.count() == 1 + + ArticleVersion = version_class(self.Article) + versions_q = self.session.query(ArticleVersion)\ + .order_by(ArticleVersion.transaction_id) + assert versions_q.count() == 2 + assert versions_q[1].operation_type == Operation.UPDATE + class TestUpdateWithDefaultValues(TestCase): def create_models(self):