diff --git a/.gitignore b/.gitignore
index a015a2d6..e76e7c49 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,3 +34,11 @@ nosetests.xml
.mr.developer.cfg
.project
.pydevproject
+
+# mypy
+.mypy_cache/
+
+# Unit test / coverage reports
+.cache
+
+\.idea/
diff --git a/.travis.yml b/.travis.yml
index 0b27edb1..d5187594 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -4,18 +4,21 @@ addons:
env:
- DB=mysql
- DB=postgres
+ - DB=postgres-native
- DB=sqlite
before_script:
- sh -c "if [ '$DB' = 'postgres' ]; then psql -c 'create database sqlalchemy_continuum_test;' -U postgres; fi"
+ - sh -c "if [ '$DB' = 'postgres-native' ]; then psql -c 'create database sqlalchemy_continuum_test;' -U postgres; fi"
- sh -c "if [ '$DB' = 'mysql' ]; then mysql -e 'create database sqlalchemy_continuum_test;'; fi"
language: python
python:
- - 2.6
- 2.7
- - 3.3
+ - 3.4
+ - 3.5
+ - 3.6
install:
- pip install -e ".[test]"
script:
- - python setup.py test
+ - py.test
diff --git a/CHANGES.rst b/CHANGES.rst
index b4d97e96..1d98e6dd 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -3,9 +3,211 @@ Changelog
Here you can see the full list of changes between each SQLAlchemy-Continuum release.
-1.0-b1 (2014-03-14)
+
+1.3.6 (2018-07-30)
+^^^^^^^^^^^^^^^^^^
+
+- Fixed ResourceClosedErrors from connections leaking when using an external transaction (#196, courtesy of vault)
+
+
+1.3.5 (2018-06-03)
+^^^^^^^^^^^^^^^^^^
+
+- Track cloned connections (#167, courtesy of netcriptus)
+
+
+1.3.4 (2018-03-07)
+^^^^^^^^^^^^^^^^^^
+
+- Exclude many-to-many properties from versioning if they are added in exclude parameter (#169, courtesy of fuhrysteve)
+
+
+1.3.3 (2017-11-05)
+^^^^^^^^^^^^^^^^^^
+
+- Fixed changeset when updating object in same transaction as inserting it (#141, courtesy of oinopion)
+
+
+1.3.2 (2017-10-12)
+^^^^^^^^^^^^^^^^^^
+
+- Fixed multiple schema handling (#132, courtesy of vault)
+
+
+1.3.1 (2017-06-28)
+^^^^^^^^^^^^^^^^^^
+
+- Fixed subclass retrieval for closest_matching_table (#163, courtesy of debonzi)
+
+
+1.3.0 (2017-01-30)
+^^^^^^^^^^^^^^^^^^
+
+- Dropped py2.6 support
+- Fixed memory leaks with UnitOfWork instances (#131, courtesy of quantus)
+
+
+1.2.4 (2016-01-10)
+^^^^^^^^^^^^^^^^^^
+
+- Added explicit sequence names for Oracle (#118, courtesy of apfeiffer1)
+
+
+1.2.3 (2016-01-10)
+^^^^^^^^^^^^^^^^^^
+
+- Added use_module_name configuration option (#119, courtesy of kyheo)
+
+
+1.2.2 (2015-12-08)
+^^^^^^^^^^^^^^^^^^
+
+- Fixed some relationship changes not counted as modifications (#116, courtesy of tvuotila)
+
+
+1.2.1 (2015-09-27)
+^^^^^^^^^^^^^^^^^^
+
+- Fixed deep joined table inheritance handling (#105, courtesy of piotr-dobrogost)
+- Fixed naive assumption of related User model always having id column (#107, courtesy of avilaton)
+- Fixed one-to-many relationship reverting (#102, courtesy of sdorazio)
+
+
+1.2.0 (2015-07-31)
+^^^^^^^^^^^^^^^^^^
+
+- Removed generated changes attribute from version classes. This attribute can be accessed through `transaction.changes`
+- Removed is_modified checking from insert operations
+
+
+1.1.5 (2014-12-28)
+^^^^^^^^^^^^^^^^^^
+
+- Added smart primary key type inspection for user class (#86, courtesy of mattupstate)
+- Added support for self-referential version relationship reflection (#88, courtesy of dtheodor)
+
+
+1.1.4 (2014-12-06)
+^^^^^^^^^^^^^^^^^^
+
+- Fixed One-To-Many version relationship handling (#82, courtesy of dtheodor)
+- Fixed Many-To-Many version relationship handling (#83, courtesy of dtheodor)
+- Fixed inclusion and exclusion of aliased columns
+- Removed automatic exclusion of auto-assigned datetime columns and tsvector columns (explicit is better than implicit)
+
+
+1.1.3 (2014-10-23)
+^^^^^^^^^^^^^^^^^^
+
+- Made FlaskPlugin accepts overriding of current_user_id_factory and remote_addr_factory
+
+
+1.1.2 (2014-10-07)
+^^^^^^^^^^^^^^^^^^
+
+- Fixed identifier quoting in trigger syncing
+
+
+1.1.1 (2014-10-07)
+^^^^^^^^^^^^^^^^^^
+
+- Fixed native versioning trigger syncing
+
+
+1.1.0 (2014-10-02)
+^^^^^^^^^^^^^^^^^^
+
+- Added Python 3.4 to test suite
+- Added optional native trigger based versioning for PostgreSQL dialect
+- Added create_models option
+- Added count_versions utility function
+- Fixed custom transaction column name handling with models using joined table inheritance
+- Fixed subquery strategy support for models using joined table inheritance
+- Fixed savepoint handling
+- Fixed version model building when no versioned models were found (previously threw AttributeError)
+- Replaced plugin template methods before_create_tx_object and after_create_tx_object with transaction_args to better cope with native versioning
+
+
+1.0.3 (2014-07-16)
+^^^^^^^^^^^^^^^^^^
+
+- Added __repr__ for Operations class
+- Fixed an issue where assigning unmodified object's attributes in user defined before flush listener would raise TypeError in UnitOfWork
+
+
+1.0.2 (2014-07-11)
+^^^^^^^^^^^^^^^^^^
+
+- Allowed easier overriding of PropertyModTracker column creation
+- Rewrote join table inheritance handling schematics (now working with SA 0.9.6)
+- SQLAlchemy-Utils dependency updated to 0.26.5
+
+
+1.0.1 (2014-06-18)
+^^^^^^^^^^^^^^^^^^
+
+- Fixed an issue where deleting an object with deferred columns would throw ObjectDeletedError.
+- Made viewonly relationships with association tables not register the association table to versioning manager registry.
+
+
+1.0 (2014-06-16)
^^^^^^^^^^^^^^^^
+- Added __repr__ for Transaction class, issue #59
+- Made transaction_cls of VersioningManager configurable.
+- Removed generic relationships from transaction class to versioned classes.
+- Removed generic relationships from transaction changes class to versioned classes.
+- Removed relation_naming_function (no longer needed)
+- Moved get_bind to SQLAlchemy-Utils
+- Removed inflection package from dependencies (no longer needed)
+- SQLAlchemy-Utils dependency updated to 0.26.2
+
+
+1.0b5 (2014-05-07)
+^^^^^^^^^^^^^^^^^^
+
+- Added order_by mapper arg ignoring for version class reflection if other than string argument is used
+- Added support for customizing the User class which the Transaction class should have relationship to (issue #53)
+- Changed get_versioning_manager to throw ClassNotVersioned exception if first argument is not a versioned class
+- Fixed relationship reflection from versioned classes to non versioned classes (issue #52)
+- SQLAlchemy-Utils dependency updated to 0.25.4
+
+
+1.0-b4 (2014-04-20)
+^^^^^^^^^^^^^^^^^^^
+
+- Fixed many-to-many unit of work inspection when using engine bind instead of collection bind
+- Fixed various issues if primary key aliases were used in declarative models
+- Fixed an issue where association versioning would not work with custom transaction column name
+- SQLAlchemy-Utils dependency updated to 0.25.3
+
+
+1.0-b3 (2014-04-19)
+^^^^^^^^^^^^^^^^^^^
+
+- Added support for concrete inheritance
+- Added order_by mapper arg reflection to version classes
+- Added support for column_prefix mapper arg
+- Made model builder copy inheritance mapper args to version classes from parent classes
+- Fixed end transaction id setting for join table inheritance classes. Now end transaction id is set explicitly to all tables in inheritance hierarchy.
+- Fixed single table inheritance handling
+
+
+1.0-b2 (2014-04-09)
+^^^^^^^^^^^^^^^^^^^
+
+- Added some schema tools to help migrating between different plugins and versioning strategies
+- Added remove_versioning utility function, see issue #45
+- Added order_by transaction_id default to versions relationship
+- Fixed PropertyModTrackerPlugin association table handling.
+- Fixed get_bind schematics (Flask-SQLAlchemy integration wasn't working)
+- Fixed a bug where committing a session without objects would result in KeyError
+- SQLAlchemy dependency updated to 0.9.4
+
+
+1.0-b1 (2014-03-14)
+^^^^^^^^^^^^^^^^^^^
+
- Added new plugin architecture
- Added ActivityPlugin
- Naming conventions change: History -> Version (to be consistent throughout Continuum)
diff --git a/README.rst b/README.rst
index ac90c698..b6a2bda0 100644
--- a/README.rst
+++ b/README.rst
@@ -16,6 +16,7 @@ Features
- Transactions can be queried afterwards using SQLAlchemy query syntax
- Query for changed records at given transaction
- Temporal relationship reflection. Version object's relationship show the parent objects relationships as they where in that point in time.
+- Supports native versioning for PostgreSQL database (trigger based versioning)
QuickStart
@@ -40,7 +41,7 @@ In order to make your models versioned you need two things:
from sqlalchemy_continuum import make_versioned
- make_versioned()
+ make_versioned(user_cls=None)
class Article(Base):
@@ -77,7 +78,7 @@ In order to make your models versioned you need two things:
Resources
---------
-- `Documentation `_
+- `Documentation `_
- `Issue Tracker `_
- `Code `_
@@ -87,7 +88,20 @@ Resources
.. |Build Status| image:: https://travis-ci.org/kvesteri/sqlalchemy-continuum.png?branch=master
:target: https://travis-ci.org/kvesteri/sqlalchemy-continuum
-.. |Version Status| image:: https://pypip.in/v/SQLAlchemy-Continuum/badge.png
- :target: https://crate.io/packages/SQLAlchemy-Continuum/
-.. |Downloads| image:: https://pypip.in/d/SQLAlchemy-Continuum/badge.png
- :target: https://crate.io/packages/SQLAlchemy-Continuum/
+.. |Version Status| image:: https://img.shields.io/pypi/v/SQLAlchemy-Continuum.png
+ :target: https://pypi.python.org/pypi/SQLAlchemy-Continuum/
+.. |Downloads| image:: https://img.shields.io/pypi/dm/SQLAlchemy-Continuum.png
+ :target: https://pypi.python.org/pypi/SQLAlchemy-Continuum/
+
+
+More information
+----------------
+
+- http://en.wikipedia.org/wiki/Slowly_changing_dimension
+- http://en.wikipedia.org/wiki/Change_data_capture
+- http://en.wikipedia.org/wiki/Anchor_Modeling
+- http://en.wikipedia.org/wiki/Shadow_table
+- https://wiki.postgresql.org/wiki/Audit_trigger
+- https://wiki.postgresql.org/wiki/Audit_trigger_91plus
+- http://kosalads.blogspot.fi/2014/06/implement-audit-functionality-in.html
+- https://github.com/2ndQuadrant/pgaudit
diff --git a/benchmark.py b/benchmark.py
new file mode 100644
index 00000000..9ee0250f
--- /dev/null
+++ b/benchmark.py
@@ -0,0 +1,140 @@
+import itertools as it
+import warnings
+from copy import copy
+from time import time
+
+import sqlalchemy as sa
+from sqlalchemy import create_engine
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy_continuum import (
+ make_versioned,
+ versioning_manager,
+ remove_versioning
+)
+from sqlalchemy_continuum.transaction import TransactionFactory
+from sqlalchemy_continuum.plugins import (
+ PropertyModTrackerPlugin,
+ TransactionMetaPlugin,
+ TransactionChangesPlugin
+)
+from termcolor import colored
+
+warnings.simplefilter('error', sa.exc.SAWarning)
+
+
+def test_versioning(
+ native_versioning,
+ versioning_strategy,
+ property_mod_tracking
+):
+ transaction_column_name = 'transaction_id'
+ end_transaction_column_name = 'end_transaction_id'
+ plugins = [TransactionChangesPlugin(), TransactionMetaPlugin()]
+
+ if property_mod_tracking:
+ plugins.append(PropertyModTrackerPlugin())
+ transaction_cls = TransactionFactory()
+ user_cls = None
+
+ Model = declarative_base()
+
+ options = {
+ 'create_models': True,
+ 'native_versioning': native_versioning,
+ 'base_classes': (Model, ),
+ 'strategy': versioning_strategy,
+ 'transaction_column_name': transaction_column_name,
+ 'end_transaction_column_name': end_transaction_column_name,
+ }
+
+ make_versioned(options=options)
+
+ dns = 'postgres://postgres@localhost/sqlalchemy_continuum_test'
+ versioning_manager.plugins = plugins
+ versioning_manager.transaction_cls = transaction_cls
+ versioning_manager.user_cls = user_cls
+
+ engine = create_engine(dns)
+ # engine.echo = True
+
+ class Article(Model):
+ __tablename__ = 'article'
+ __versioned__ = copy(options)
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255), nullable=False)
+ content = sa.Column(sa.UnicodeText)
+ description = sa.Column(sa.UnicodeText)
+
+ class Tag(Model):
+ __tablename__ = 'tag'
+ __versioned__ = copy(options)
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id))
+ article = sa.orm.relationship(Article, backref='tags')
+
+
+ sa.orm.configure_mappers()
+
+ connection = engine.connect()
+
+ Model.metadata.create_all(connection)
+
+ Session = sessionmaker(bind=connection)
+ session = Session(autoflush=False)
+ session.execute('CREATE EXTENSION IF NOT EXISTS hstore')
+
+ Model.metadata.create_all(connection)
+
+ start = time()
+
+ for i in range(20):
+ for i in range(20):
+ session.add(Article(name=u'Article', tags=[Tag(), Tag()]))
+ session.commit()
+
+ print 'Testing with:'
+ print ' native_versioning=%r' % native_versioning
+ print ' versioning_strategy=%r' % versioning_strategy
+ print ' property_mod_tracking=%r' % property_mod_tracking
+ print colored('%r seconds' % (time() - start), 'red')
+
+ Model.metadata.drop_all(connection)
+
+ remove_versioning()
+ versioning_manager.reset()
+
+ session.close_all()
+ session.expunge_all()
+ Model.metadata.drop_all(connection)
+ engine.dispose()
+ connection.close()
+
+
+
+setting_variants = {
+ 'versioning_strategy': [
+ 'subquery',
+ 'validity',
+ ],
+ 'native_versioning': [
+ True,
+ False
+ ],
+ 'property_mod_tracking': [
+ False,
+ True
+ ]
+}
+
+
+names = sorted(setting_variants)
+combinations = [
+ dict(zip(names, prod))
+ for prod in
+ it.product(*(setting_variants[name] for name in names))
+]
+for combination in combinations:
+ test_versioning(**combination)
diff --git a/docs/api.rst b/docs/api.rst
index 3858a650..f601a2c5 100644
--- a/docs/api.rst
+++ b/docs/api.rst
@@ -2,6 +2,8 @@ API Documentation
=================
+.. module:: sqlalchemy_continuum
+
.. autofunction:: make_versioned
diff --git a/docs/configuration.rst b/docs/configuration.rst
index f9edf1cf..37884b6e 100644
--- a/docs/configuration.rst
+++ b/docs/configuration.rst
@@ -39,7 +39,7 @@ Validity
The 'validity' strategy saves two columns in each history table, namely 'transaction_id' and 'end_transaction_id'. The names of these columns can be configured with configuration options `transaction_column_name` and `end_transaction_column_name`.
-As with 'subquery' strategy for each inserted, updated and deleted entity Continuum creates new version in the history table. However it also updates the end_transaction_id of the previous version to point at the current version. This creates a little be of overhead during data manipulation.
+As with 'subquery' strategy for each inserted, updated and deleted entity Continuum creates new version in the history table. However it also updates the end_transaction_id of the previous version to point at the current version. This creates a little bit of overhead during data manipulation.
With 'validity' strategy version traversal is very fast. When accessing previous version Continuum tries to find the version record where the primary keys match and end_transaction_id is the same as the transaction_id of the given version record. When accessing the next version Continuum tries to find the version record where the primary keys match and transaction_id is the same as the end_transaction_id of the given version record.
@@ -73,23 +73,7 @@ Cons:
Column exclusion and inclusion
------------------------------
-With `include` and `exclude` configuration options you can define which entity attributes you want to get versioned. By default Continuum versions all entity attributes except DateTime columns with default values. If you want to include this columns you have to pass them to `include`.
-
-
-::
-
-
- class User(Base):
- __versioned__ = {
- 'include': ['created_at']
- }
-
- id = sa.Column(sa.Integer, primary_key=True)
- name = sa.Column(sa.Unicode(255))
- created_at = sa.Column(sa.DateTime)
-
-
-Sometimes you may have columns you want to exclude from the history classes. You may pass the column names to `exclude` option as follows:
+With `exclude` configuration option you can define which entity attributes you want to get versioned. By default Continuum versions all entity attributes.
::
@@ -126,11 +110,6 @@ Here is a full list of configuration options:
* operation_type_column_name (default: 'operation_type')
The name of the operation type column (used by history tables).
-* relation_naming_function (default: lambda a: pluralize(underscore(a)))
- The relation naming function that is being used for generating the relationship names between various generated models.
-
- For example lets say you have versioned class called 'User'. By default Continuum builds relationship from TransactionLog with name 'users' that points to User class.
-
* strategy (default: 'validity')
The versioning strategy to use. Either 'validity' or 'subquery'
@@ -150,6 +129,27 @@ Example
content = sa.Column(sa.UnicodeText)
+Customizing transaction user class
+----------------------------------
+
+By default Continuum tries to build a relationship between 'User' class and Transaction class. If you have differently named user class you can simply pass its name to make_versioned:
+
+
+::
+
+
+ make_versioned(user_cls='MyUserClass')
+
+
+
+If you don't want transactions to contain any user references you can also disable this feature.
+
+
+::
+
+ make_versioned(user_cls=None)
+
+
Customizing versioned mappers
-----------------------------
diff --git a/docs/index.rst b/docs/index.rst
index 45cf921f..71937337 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -11,8 +11,9 @@ SQLAlchemy-Continuum is a versioning extension for SQLAlchemy.
version_objects
revert
queries
- plugins
transactions
+ native_versioning
+ plugins
configuration
schema
alembic
diff --git a/docs/intro.rst b/docs/intro.rst
index d6b94ab8..aca33422 100644
--- a/docs/intro.rst
+++ b/docs/intro.rst
@@ -5,7 +5,7 @@ Introduction
Why?
^^^^
-SQLAlchemy already has versioning extension. This extension however is very limited. It does not support versioning entire transactions.
+SQLAlchemy already has a versioning extension. This extension however is very limited. It does not support versioning entire transactions.
Hibernate for Java has Envers, which had nice features but lacks a nice API. Ruby on Rails has papertrail_, which has very nice API but lacks the efficiency and feature set of Envers.
@@ -54,7 +54,7 @@ In order to make your models versioned you need two things:
from sqlalchemy_continuum import make_versioned
- make_versioned()
+ make_versioned(user_cls=None)
class Article(Base):
@@ -82,12 +82,12 @@ When the models have been configured either by calling configure_mappers() or by
::
- from sqlalchemy_continuum import history_class, parent_class
+ from sqlalchemy_continuum import version_class, parent_class
- history_class(Article) # ArticleHistory class
+ version_class(Article) # ArticleHistory class
- parent_class(history_class(Article)) # Article class
+ parent_class(version_class(Article)) # Article class
Versions and transactions
diff --git a/docs/native_versioning.rst b/docs/native_versioning.rst
new file mode 100644
index 00000000..60e215fe
--- /dev/null
+++ b/docs/native_versioning.rst
@@ -0,0 +1,32 @@
+Native versioning
+=================
+
+As of version 1.1 SQLAlchemy-Continuum supports native versioning for PostgreSQL dialect.
+Native versioning creates SQL triggers for all versioned models. These triggers keep track of changes made to versioned models. Compared to object based versioning, native versioning has
+
+* Much faster than regular object based versioning
+* Minimal memory footprint when used alongside `create_tables=False` and `create_models=False` configuration options.
+* More cumbersome database migrations, since triggers need to be updated also.
+
+Usage
+-----
+
+For enabling native versioning you need to set `native_versioning` configuration option as `True`.
+
+::
+
+ make_versioned(options={'native_versioning': True})
+
+
+
+Schema migrations
+-----------------
+
+When making schema migrations (for example adding new columns to version tables) you need to remember to call sync_trigger in order to keep the version trigger up-to-date.
+
+::
+
+ from sqlalchemy_continuum.dialects.postgresql import sync_trigger
+
+
+ sync_trigger(conn, 'article_version')
diff --git a/docs/schema.rst b/docs/schema.rst
index 5ca4fa8b..e13ecb53 100644
--- a/docs/schema.rst
+++ b/docs/schema.rst
@@ -2,22 +2,22 @@ Continuum Schema
================
-History tables
+Version tables
--------------
-By default SQLAlchemy-Continuum creates a history table for each versioned entity table. The history tables are suffixed with '_history'. So for example if you have two versioned tables 'article' and 'category', SQLAlchemy-Continuum would create two history models 'article_history' and 'category_history'.
+By default SQLAlchemy-Continuum creates a version table for each versioned entity table. The version tables are suffixed with '_version'. So for example if you have two versioned tables 'article' and 'category', SQLAlchemy-Continuum would create two version tables 'article_version' and 'category_version'.
-By default the history tables contain these columns:
+By default the version tables contain these columns:
* id of the original entity (this can be more then one column in the case of composite primary keys)
* transaction_id - an integer that matches to the id number in the transaction_log table.
-* end_transaction_id - an integer that matches the next history record's transaction_id. If this is the current history record then this field is null.
+* end_transaction_id - an integer that matches the next version record's transaction_id. If this is the current version record then this field is null.
* operation_type - a small integer defining the type of the operation
* versioned fields from the original entity
If you are using :ref:`property-mod-tracker` Continuum also creates one boolean field for each versioned field. By default these boolean fields are suffixed with '_mod'.
-The primary key of each history table is the combination of parent table's primary key + the transaction_id. This means there can be at most one history table entry for a given entity instance at given transaction.
+The primary key of each version table is the combination of parent table's primary key + the transaction_id. This means there can be at most one version table entry for a given entity instance at given transaction.
Transaction tables
------------------
@@ -32,3 +32,13 @@ Using vacuum
.. module:: sqlalchemy_continuum
.. autofunction:: vacuum
+
+
+Schema tools
+------------
+
+.. module:: sqlalchemy_continuum.schema
+
+.. autofunction:: update_end_tx_column
+
+.. autofunction:: update_property_mod_flags
diff --git a/docs/transactions.rst b/docs/transactions.rst
index 4d91310a..d520d40b 100644
--- a/docs/transactions.rst
+++ b/docs/transactions.rst
@@ -13,7 +13,8 @@ Transaction can be queried just like any other sqlalchemy declarative model.
::
- Transaction = Article.__versioned__['transaction_class']
+ from sqlalchemy_continuum import transaction_class
+ Transaction = transaction_class(Article)
# find all transactions
session.query(Transaction).all()
@@ -72,9 +73,9 @@ This would execute the following SQL queries (on PostgreSQL)
1. INSERT INTO article (name, content) VALUES (?, ?)
params: ('Some article', 'Some content')
-2. INSERT INTO transaction_log (issued_at) VALUES (?)
+2. INSERT INTO transaction (issued_at) VALUES (?)
params: (datetime.utcnow())
-3. INSERT INTO article_history (id, name, content, transaction_id) VALUES (?, ?, ?, ?)
+3. INSERT INTO article_version (id, name, content, transaction_id) VALUES (?, ?, ?, ?)
params: (, 'Some article', 'Some content', )
diff --git a/docs/utilities.rst b/docs/utilities.rst
index 3dd70324..161fb0b3 100644
--- a/docs/utilities.rst
+++ b/docs/utilities.rst
@@ -11,6 +11,12 @@ changeset
.. autofunction:: changeset
+count_versions
+--------------
+
+.. autofunction:: count_versions
+
+
get_versioning_manager
----------------------
diff --git a/docs/version_objects.rst b/docs/version_objects.rst
index 7fe2d195..ecbb71ce 100644
--- a/docs/version_objects.rst
+++ b/docs/version_objects.rst
@@ -132,13 +132,64 @@ Lastly we check the category relations of different article versions.
session.commit()
- article.versions[0].category.name = u'Some category'
- article.versions[1].category.name = u'Some other category'
+ article.versions[0].category.name # u'Some category'
+ article.versions[1].category.name # u'Some other category'
The logic how SQLAlchemy-Continuum builds these relationships is within the RelationshipBuilder class.
+Relationships to non-versioned classes
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Let's take previous example of Articles and Categories. Now consider that only Article model is versioned:
+
+
+::
+
+
+ class Article(Base):
+ __tablename__ = 'article'
+ __versioned__ = {}
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255), nullable=False)
+
+
+ class Category(Base):
+ __tablename__ = 'tag'
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+ article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id))
+ article = sa.orm.relationship(
+ Article,
+ backref=sa.orm.backref('categories')
+ )
+
+
+Here Article versions will still reflect the relationships of Article model but they will simply return Category objects instead of CategoryVersion objects:
+
+
+::
+
+
+ category = Category(name=u'Some category')
+ article = Article(
+ name=u'Some article',
+ category=category
+ )
+ session.add(article)
+ session.commit()
+
+ article.category = Category(name=u'Some other category')
+ session.commit()
+
+ version = article.versions[0]
+ version.category.name # u'Some other category'
+ isinstance(version.category, Category) # True
+
+
Dynamic relationships
^^^^^^^^^^^^^^^^^^^^^
@@ -153,9 +204,10 @@ If the parent class has a dynamic relationship it will be reflected as a propert
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255), nullable=False)
+
class Tag(Base):
__tablename__ = 'tag'
- __versioned__ = {}
+ __versioned__ = {}
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
@@ -177,5 +229,5 @@ If the parent class has a dynamic relationship it will be reflected as a propert
tag_query = article.versions[0].tags
tag_query.all() # return all tags for given version
- tag_query.count() # return the tag count for given versoin
+ tag_query.count() # return the tag count for given version
diff --git a/setup.py b/setup.py
index 95a268f7..6dbcf5a6 100644
--- a/setup.py
+++ b/setup.py
@@ -5,23 +5,22 @@
Versioning and auditing extension for SQLAlchemy.
"""
-from setuptools import setup, Command
-import subprocess
+import os
import sys
+import re
+from setuptools import setup
-class PyTest(Command):
- user_options = []
+HERE = os.path.dirname(os.path.abspath(__file__))
+PY3 = sys.version_info[0] == 3
- def initialize_options(self):
- pass
- def finalize_options(self):
- pass
-
- def run(self):
- errno = subprocess.call(['py.test'])
- raise SystemExit(errno)
+def get_version():
+ filename = os.path.join(HERE, 'sqlalchemy_continuum', '__init__.py')
+ with open(filename) as f:
+ contents = f.read()
+ pattern = r"^__version__ = '(.*?)'$"
+ return re.search(pattern, contents, re.MULTILINE).group(1)
extras_require = {
@@ -35,7 +34,9 @@ def run(self):
'anyjson': ['anyjson>=0.3.3'],
'flask': ['Flask>=0.9'],
'flask-login': ['Flask-Login>=0.2.9'],
- 'i18n': ['SQLAlchemy-i18n >= 0.8.2'],
+ 'flask-sqlalchemy': ['Flask-SQLAlchemy>=1.0'],
+ 'flexmock': ['flexmock>=0.9.7'],
+ 'i18n': ['SQLAlchemy-i18n>=0.8.4'],
}
@@ -47,7 +48,7 @@ def run(self):
setup(
name='SQLAlchemy-Continuum',
- version='1.0-b1',
+ version=get_version(),
url='https://github.com/kvesteri/sqlalchemy-continuum',
license='BSD',
author='Konsta Vesterinen',
@@ -56,26 +57,29 @@ def run(self):
long_description=__doc__,
packages=[
'sqlalchemy_continuum',
- 'sqlalchemy_continuum.plugins'
+ 'sqlalchemy_continuum.plugins',
+ 'sqlalchemy_continuum.dialects'
],
zip_safe=False,
include_package_data=True,
platforms='any',
install_requires=[
- 'SQLAlchemy>=0.9.3',
- 'SQLAlchemy-Utils>=0.25.0',
- 'inflection>=0.2.0',
- 'ordereddict>=1.1'
- if sys.version_info[0] == 2 and sys.version_info[1] < 7 else ''
+ 'SQLAlchemy>=1.0.8',
+ 'SQLAlchemy-Utils>=0.30.12'
],
extras_require=extras_require,
- cmdclass={'test': PyTest},
classifiers=[
'Environment :: Web Environment',
'Intended Audience :: Developers',
'License :: OSI Approved :: BSD License',
'Operating System :: OS Independent',
'Programming Language :: Python',
+ 'Programming Language :: Python :: 2',
+ 'Programming Language :: Python :: 2.7',
+ 'Programming Language :: Python :: 3',
+ 'Programming Language :: Python :: 3.4',
+ 'Programming Language :: Python :: 3.5',
+ 'Programming Language :: Python :: 3.6',
'Topic :: Internet :: WWW/HTTP :: Dynamic Content',
'Topic :: Software Development :: Libraries :: Python Modules'
]
diff --git a/sqlalchemy_continuum/__init__.py b/sqlalchemy_continuum/__init__.py
index bf3805f1..8e212af1 100644
--- a/sqlalchemy_continuum/__init__.py
+++ b/sqlalchemy_continuum/__init__.py
@@ -1,18 +1,24 @@
import sqlalchemy as sa
+from .exc import ClassNotVersioned, ImproperlyConfigured
from .manager import VersioningManager
from .operation import Operation
+from .transaction import TransactionFactory
from .unit_of_work import UnitOfWork
from .utils import (
changeset,
+ count_versions,
get_versioning_manager,
+ is_modified,
+ is_session_modified,
parent_class,
transaction_class,
+ tx_column_name,
vacuum,
version_class,
)
-__version__ = '1.0-b1'
+__version__ = '1.3.6'
versioning_manager = VersioningManager()
@@ -23,7 +29,8 @@ def make_versioned(
session=sa.orm.session.Session,
manager=versioning_manager,
plugins=None,
- options=None
+ options=None,
+ user_cls='User'
):
"""
This is the public API function of SQLAlchemy-Continuum for making certain
@@ -36,24 +43,80 @@ def make_versioned(
SQLAlchemy session to apply the versioning to. By default this is
sa.orm.session.Session meaning it applies to all Session subclasses.
:param manager:
- The versioning manager. Override this if you want to use one of
- SQLAlchemy-Continuum's extensions (eg. Flask extension)
+ SQLAlchemy-Continuum versioning manager.
:param plugins:
Plugins to pass for versioning manager.
:param options:
A dictionary of VersioningManager options.
+ :param user_cls:
+ User class which the Transaction class should have relationship to.
+ This can either be a class or string name of a class for lazy
+ evaluation.
"""
if plugins is not None:
manager.plugins = plugins
+
+ if options is not None:
+ manager.options.update(options)
+
+ manager.user_cls = user_cls
manager.apply_class_configuration_listeners(mapper)
manager.track_operations(mapper)
manager.track_session(session)
- if options is not None:
- manager.options.update(options)
+ sa.event.listen(
+ sa.engine.Engine,
+ 'before_cursor_execute',
+ manager.track_association_operations
+ )
+
+ sa.event.listen(
+ sa.engine.Engine,
+ 'rollback',
+ manager.clear_connection
+ )
sa.event.listen(
+ sa.engine.Engine,
+ 'set_connection_execution_options',
+ manager.track_cloned_connections
+ )
+
+
+def remove_versioning(
+ mapper=sa.orm.mapper,
+ session=sa.orm.session.Session,
+ manager=versioning_manager
+):
+ """
+ Remove the versioning from given mapper / session and manager.
+
+ :param mapper:
+ SQLAlchemy mapper to remove the versioning from.
+ :param session:
+ SQLAlchemy session to remove the versioning from. By default this is
+ sa.orm.session.Session meaning it applies to all sessions.
+ :param manager:
+ SQLAlchemy-Continuum versioning manager.
+ """
+ manager.reset()
+ manager.remove_class_configuration_listeners(mapper)
+ manager.remove_operations_tracking(mapper)
+ manager.remove_session_tracking(session)
+ sa.event.remove(
sa.engine.Engine,
'before_cursor_execute',
manager.track_association_operations
)
+
+ sa.event.remove(
+ sa.engine.Engine,
+ 'rollback',
+ manager.clear_connection
+ )
+
+ sa.event.remove(
+ sa.engine.Engine,
+ 'set_connection_execution_options',
+ manager.track_cloned_connections
+ )
diff --git a/sqlalchemy_continuum/builder.py b/sqlalchemy_continuum/builder.py
index 28be06d8..47bbef35 100644
--- a/sqlalchemy_continuum/builder.py
+++ b/sqlalchemy_continuum/builder.py
@@ -1,14 +1,33 @@
from copy import copy
+from inspect import getmro
import sqlalchemy as sa
-from sqlalchemy_utils.functions import declarative_base
+from sqlalchemy_utils.functions import get_declarative_base
-from .table_builder import TableBuilder
+from .dialects.postgresql import create_versioning_trigger_listeners
from .model_builder import ModelBuilder
from .relationship_builder import RelationshipBuilder
+from .table_builder import TableBuilder
class Builder(object):
+ def build_triggers(self):
+ """
+ Build native database versioning triggers for all versioned models that
+ were collected during class instrumentation process.
+ """
+ processed_tables = set()
+ for cls in self.manager.pending_classes:
+ if not self.manager.option(cls, 'versioning'):
+ continue
+
+ if self.manager.option(cls, 'native_versioning'):
+ cls.__versioning_manager__ = self.manager
+
+ if cls.__table__ not in processed_tables:
+ create_versioning_trigger_listeners(self.manager, cls)
+ processed_tables.add(cls.__table__)
+
def build_tables(self):
"""
Build tables for version models based on classes that were collected
@@ -18,23 +37,24 @@ def build_tables(self):
if not self.manager.option(cls, 'versioning'):
continue
- inherited_table = None
- for class_ in self.manager.tables:
- if (issubclass(cls, class_) and
- cls.__table__ == class_.__table__):
- inherited_table = self.manager.tables[class_]
- break
-
- builder = TableBuilder(
- self.manager,
- cls.__table__,
- model=cls
- )
- if inherited_table is not None:
- self.manager.tables[class_] = builder(inherited_table)
- else:
- table = builder()
- self.manager.tables[cls] = table
+ if self.manager.option(cls, 'create_tables'):
+ inherited_table = None
+ for class_ in self.manager.tables:
+ if (issubclass(cls, class_) and
+ cls.__table__ == class_.__table__):
+ inherited_table = self.manager.tables[class_]
+ break
+
+ builder = TableBuilder(
+ self.manager,
+ cls.__table__,
+ model=cls
+ )
+ if inherited_table is not None:
+ self.manager.tables[class_] = builder(inherited_table)
+ else:
+ table = builder()
+ self.manager.tables[cls] = table
def closest_matching_table(self, model):
"""
@@ -46,9 +66,10 @@ def closest_matching_table(self, model):
"""
if model in self.manager.tables:
return self.manager.tables[model]
- for cls in self.manager.tables:
- if issubclass(model, cls):
- return self.manager.tables[cls]
+ subclasses = [cls for cls in self.manager.tables if issubclass(model, cls)]
+ ordered_subclasses = [cls for cls in getmro(model) if cls in subclasses]
+ return self.manager.tables[ordered_subclasses[0]] if ordered_subclasses else None
+
def build_models(self):
"""
@@ -56,11 +77,6 @@ def build_models(self):
during class instrumentation process.
"""
if self.manager.pending_classes:
- cls = self.manager.pending_classes[0]
- self.manager.declarative_base = declarative_base(cls)
- self.manager.create_transaction_model()
- self.manager.plugins.after_build_tx_class(self.manager)
-
for cls in self.manager.pending_classes:
if not self.manager.option(cls, 'versioning'):
continue
@@ -78,7 +94,7 @@ def build_models(self):
version_cls
)
- self.manager.plugins.after_build_models(self.manager)
+ self.manager.plugins.after_build_models(self.manager)
def build_relationships(self, version_classes):
"""
@@ -112,6 +128,19 @@ def instrument_versioned_classes(self, mapper, cls):
self.manager.pending_classes.append(cls)
self.manager.metadata = cls.metadata
+ if hasattr(cls, '__version_parent__'):
+ parent = cls.__version_parent__
+ self.manager.version_class_map[parent] = cls
+ self.manager.parent_class_map[cls] = parent
+ del cls.__version_parent__
+
+ def build_transaction_class(self):
+ if self.manager.pending_classes:
+ cls = self.manager.pending_classes[0]
+ self.manager.declarative_base = get_declarative_base(cls)
+ self.manager.create_transaction_model()
+ self.manager.plugins.after_build_tx_class(self.manager)
+
def configure_versioned_classes(self):
"""
Configures all versioned classes that were collected during
@@ -127,7 +156,14 @@ def configure_versioned_classes(self):
if not self.manager.options['versioning']:
return
+ self.build_triggers()
self.build_tables()
+ self.build_transaction_class()
+
+ if not self.manager.options['create_models']:
+ self.manager.pending_classes = []
+ return
+
self.build_models()
# Create copy of all pending versioned classes so that we can inspect
diff --git a/sqlalchemy_continuum/dialects/__init__.py b/sqlalchemy_continuum/dialects/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/sqlalchemy_continuum/dialects/postgresql.py b/sqlalchemy_continuum/dialects/postgresql.py
new file mode 100644
index 00000000..f24d9077
--- /dev/null
+++ b/sqlalchemy_continuum/dialects/postgresql.py
@@ -0,0 +1,526 @@
+import sqlalchemy as sa
+
+from sqlalchemy_continuum.plugins import PropertyModTrackerPlugin
+
+
+trigger_sql = """
+CREATE TRIGGER {trigger_name}
+AFTER INSERT OR UPDATE OR DELETE ON {table_name}
+FOR EACH ROW EXECUTE PROCEDURE {procedure_name}()
+"""
+
+upsert_cte_sql = """
+WITH upsert as
+(
+ UPDATE {version_table_name}
+ SET {update_values}
+ WHERE
+ {transaction_column} = transaction_id_value
+ AND
+ {primary_key_criteria}
+ RETURNING *
+)
+INSERT INTO {version_table_name}
+({transaction_column}, {operation_type_column}, {column_names})
+SELECT
+ transaction_id_value,
+ {operation_type},
+ {insert_values}
+WHERE NOT EXISTS (SELECT 1 FROM upsert);
+"""
+
+temporary_transaction_sql = """
+CREATE TEMP TABLE IF NOT EXISTS {temporary_transaction_table}
+({transaction_table_columns})
+ON COMMIT DELETE ROWS;
+"""
+
+insert_temporary_transaction_sql = """
+INSERT INTO {temporary_transaction_table} ({transaction_table_columns})
+VALUES ({transaction_values});
+"""
+
+temp_transaction_trigger_sql = """
+CREATE TRIGGER transaction_trigger
+AFTER INSERT ON {transaction_table}
+FOR EACH ROW EXECUTE PROCEDURE transaction_temp_table_generator()
+"""
+
+procedure_sql = """
+CREATE OR REPLACE FUNCTION {procedure_name}() RETURNS TRIGGER AS $$
+DECLARE transaction_id_value INT;
+BEGIN
+ BEGIN
+ transaction_id_value = (SELECT id FROM temporary_transaction);
+ EXCEPTION WHEN others THEN
+ RETURN NEW;
+ END;
+ IF transaction_id_value IS NULL THEN
+ RETURN NEW;
+ END IF;
+
+ IF (TG_OP = 'INSERT') THEN
+ {after_insert}
+ {upsert_insert}
+ ELSIF (TG_OP = 'UPDATE') THEN
+ IF (hstore(NEW.*) - hstore(OLD.*) - ARRAY[{excluded_columns}]::text[])
+ = hstore('')
+ THEN
+ RETURN NULL;
+ END IF;
+ {after_update}
+ {upsert_update}
+ ELSIF (TG_OP = 'DELETE') THEN
+ {after_delete}
+ {upsert_delete}
+ END IF;
+ RETURN NEW;
+END;
+$$
+LANGUAGE plpgsql
+"""
+
+validity_sql = """
+UPDATE {version_table_name}
+SET {end_transaction_column} = transaction_id_value
+WHERE
+ {transaction_column} = (
+ SELECT MIN({transaction_column}) FROM {version_table_name}
+ WHERE {end_transaction_column} IS NULL AND {primary_key_criteria}
+ ) AND
+ {primary_key_criteria};
+"""
+
+
+def uses_property_mod_tracking(manager):
+ return any(
+ isinstance(plugin, PropertyModTrackerPlugin)
+ for plugin in manager.plugins
+ )
+
+
+class SQLConstruct(object):
+ def __init__(
+ self,
+ table,
+ transaction_column_name,
+ operation_type_column_name,
+ version_table_name_format,
+ excluded_columns=None,
+ update_validity_for_tables=None,
+ use_property_mod_tracking=False,
+ end_transaction_column_name=None,
+ ):
+ self.update_validity_for_tables = update_validity_for_tables
+ self.operation_type_column_name = operation_type_column_name
+ self.transaction_column_name = transaction_column_name
+ self.end_transaction_column_name = end_transaction_column_name
+ self.version_table_name_format = version_table_name_format
+ self.use_property_mod_tracking = use_property_mod_tracking
+ self.table = table
+ self.excluded_columns = excluded_columns
+ if update_validity_for_tables is None:
+ self.update_validity_for_tables = []
+ if self.excluded_columns is None:
+ self.excluded_columns = []
+
+ @property
+ def table_name(self):
+ if self.table.schema:
+ return '%s."%s"' % (self.table.schema, self.table.name)
+ else:
+ return '"' + self.table.name + '"'
+
+ @property
+ def transaction_table_name(self):
+ if self.table.schema:
+ return '%s.transaction' % self.table.schema
+ else:
+ return 'transaction'
+
+ @property
+ def temporary_transaction_table_name(self):
+ return 'temporary_transaction'
+
+ @property
+ def version_table_name(self):
+ version_table_name = self.version_table_name_format % self.table.name
+ if self.table.schema:
+ version_table_name = '%s.%s' % (
+ self.table.schema, version_table_name
+ )
+ return version_table_name
+
+ @classmethod
+ def for_manager(self, manager, cls):
+ strategy = manager.option(cls, 'strategy')
+ operation_type_column = manager.option(
+ cls,
+ 'operation_type_column_name'
+ )
+ excluded_columns = [
+ c.name for c in sa.inspect(cls).columns
+ if manager.is_excluded_column(cls, c)
+ ]
+ return self(
+ update_validity_for_tables=(
+ sa.inspect(cls).tables if strategy == 'validity' else []
+ ),
+ version_table_name_format=manager.option(cls, 'table_name'),
+ operation_type_column_name=operation_type_column,
+ transaction_column_name=manager.option(
+ cls, 'transaction_column_name'
+ ),
+ end_transaction_column_name=manager.option(
+ cls, 'end_transaction_column_name'
+ ),
+ use_property_mod_tracking=uses_property_mod_tracking(manager),
+ excluded_columns=excluded_columns,
+ table=cls.__table__
+ )
+
+ @property
+ def columns(self):
+ return [c for c in self.table.c if c.name not in self.excluded_columns]
+
+ @property
+ def columns_without_pks(self):
+ return [c for c in self.columns if not c.primary_key]
+
+ @property
+ def pk_columns(self):
+ return [c for c in self.columns if c.primary_key]
+
+ def copy_args(self):
+ return dict(
+ (k, v) for k, v in self.__dict__.items() if not k.startswith('__')
+ )
+
+
+class UpsertSQL(SQLConstruct):
+ builders = {
+ 'update_values': ', ',
+ 'insert_values': ', ',
+ 'column_names': ', ',
+ 'primary_key_criteria': ' AND ',
+ }
+
+ def __init__(self, *args, **kwargs):
+ SQLConstruct.__init__(self, *args, **kwargs)
+
+ for key in self.builders:
+ setattr(self, key, getattr(self, 'build_%s' % key)())
+
+ def build_column_names(self):
+ column_names = ['"%s"' % c.name for c in self.columns]
+ if self.use_property_mod_tracking:
+ column_names += [
+ '%s_mod' % c.name for c in self.columns_without_pks
+ ]
+ return column_names
+
+ def build_primary_key_criteria(self):
+ return [
+ '"{name}" = NEW."{name}"'.format(name=c.name)
+ for c in self.columns if c.primary_key
+ ]
+
+ def build_update_values(self):
+ parent_columns = [
+ '"{name}" = NEW."{name}"'.format(name=c.name)
+ for c in self.columns
+ ]
+ mod_columns = []
+ if self.use_property_mod_tracking:
+ mod_columns = [
+ '{0}_mod = {0}_mod OR OLD."{0}" IS DISTINCT FROM NEW."{0}"'
+ .format(c.name)
+ for c in self.columns_without_pks
+ ]
+
+ return (
+ ['%s = 1' % self.operation_type_column_name] +
+ parent_columns +
+ mod_columns
+ )
+
+ def build_insert_values(self):
+ values = self.build_values()
+ if self.use_property_mod_tracking:
+ values += self.build_mod_tracking_values()
+ return values
+
+ def build_values(self):
+ return ['NEW."%s"' % c.name for c in self.columns]
+
+ def build_mod_tracking_values(self):
+ return []
+
+ def __str__(self):
+ params = dict(
+ version_table_name=self.version_table_name,
+ transaction_column=self.transaction_column_name,
+ operation_type=self.operation_type,
+ operation_type_column=self.operation_type_column_name,
+ transaction_table_name=self.transaction_table_name,
+ )
+ for key, join_operator in self.builders.items():
+ params[key] = join_operator.join(getattr(self, key))
+
+ sql = upsert_cte_sql.format(**params)
+ return sql
+
+
+class DeleteUpsertSQL(UpsertSQL):
+ operation_type = 2
+
+ def build_primary_key_criteria(self):
+ return [
+ '"{name}" = OLD."{name}"'.format(name=c.name)
+ for c in self.pk_columns
+ ]
+
+ def build_mod_tracking_values(self):
+ return ['True'] * len(self.columns_without_pks)
+
+ def build_update_values(self):
+ return [
+ '"{name}" = OLD."{name}"'.format(name=c.name)
+ for c in self.columns
+ ]
+
+ def build_values(self):
+ return ['OLD."%s"' % c.name for c in self.columns]
+
+
+class InsertUpsertSQL(UpsertSQL):
+ operation_type = 0
+
+ def build_mod_tracking_values(self):
+ return ['True'] * len(self.columns_without_pks)
+
+
+class UpdateUpsertSQL(UpsertSQL):
+ operation_type = 1
+
+ def build_mod_tracking_values(self):
+ return [
+ 'OLD."{0}" IS DISTINCT FROM NEW."{0}"'
+ .format(c.name) for c in self.columns_without_pks
+ ]
+
+
+class ValiditySQL(SQLConstruct):
+ @property
+ def primary_key_criteria(self):
+ return ' AND '.join(
+ '"{name}" = NEW."{name}"'.format(name=c.name)
+ for c in self.pk_columns
+ )
+
+ def __str__(self):
+ params = dict(
+ version_table_name=self.version_table_name,
+ transaction_table_name=self.transaction_table_name,
+ transaction_column=self.transaction_column_name,
+ end_transaction_column=self.end_transaction_column_name,
+ primary_key_criteria=self.primary_key_criteria
+ )
+ return validity_sql.format(**params)
+
+
+class InsertValiditySQL(ValiditySQL):
+ pass
+
+
+class UpdateValiditySQL(ValiditySQL):
+ pass
+
+
+class DeleteValiditySQL(ValiditySQL):
+ @property
+ def primary_key_criteria(self):
+ return ' AND '.join(
+ '{name} = OLD."{name}"'.format(name=c.name)
+ for c in self.pk_columns
+ )
+
+
+def get_validity_sql(class_, tables, params):
+ params = params.copy()
+ del params['table']
+ return ''.join(str(class_(table, **params)) for table in tables)
+
+
+class CreateTriggerSQL(SQLConstruct):
+ def __str__(self):
+ return trigger_sql.format(
+ trigger_name='%s_trigger' % self.table.name,
+ table_name=self.table_name,
+ procedure_name='%s_audit' % self.table.name
+ )
+
+
+class TransactionSQLConstruct(object):
+ def __init__(self, **kwargs):
+ self.__dict__.update(kwargs)
+
+
+class CreateTemporaryTransactionTableSQL(TransactionSQLConstruct):
+ table_name = 'temporary_transaction'
+
+ def __str__(self):
+ return temporary_transaction_sql.format(
+ temporary_transaction_table=self.table_name,
+ transaction_table_columns='id BIGINT, PRIMARY KEY(id)'
+ )
+
+
+class InsertTemporaryTransactionSQL(TransactionSQLConstruct):
+ table_name = 'temporary_transaction'
+ transaction_values = 'transaction_id_value'
+
+ def __str__(self):
+ return insert_temporary_transaction_sql.format(
+ temporary_transaction_table=self.table_name,
+ transaction_table_columns='id',
+ transaction_values=self.transaction_values
+ )
+
+
+class CreateTriggerFunctionSQL(SQLConstruct):
+ def __str__(self):
+ args = self.copy_args()
+ tables = self.update_validity_for_tables
+ after_insert = get_validity_sql(InsertValiditySQL, tables, args)
+ after_update = get_validity_sql(UpdateValiditySQL, tables, args)
+ after_delete = get_validity_sql(DeleteValiditySQL, tables, args)
+
+ sql = procedure_sql.format(
+ procedure_name='%s_audit' % self.table.name,
+ excluded_columns=', '.join(
+ "'%s'" % c for c in self.excluded_columns
+ ),
+ transaction_table_name=self.transaction_table_name,
+ after_insert=after_insert,
+ after_update=after_update,
+ after_delete=after_delete,
+ temporary_transaction_sql=(
+ CreateTemporaryTransactionTableSQL()
+ ),
+ insert_temporary_transaction_sql=(
+ InsertTemporaryTransactionSQL()
+ ),
+ upsert_insert=InsertUpsertSQL(**args),
+ upsert_update=UpdateUpsertSQL(**args),
+ upsert_delete=DeleteUpsertSQL(**args)
+ )
+ return sql
+
+
+class TransactionTriggerSQL(object):
+ def __init__(self, tx_class):
+ self.table = tx_class.__table__
+
+ @property
+ def transaction_table_name(self):
+ if self.table.schema:
+ return '%s.transaction' % self.table.schema
+ else:
+ return 'transaction'
+
+ def __str__(self):
+ return temp_transaction_trigger_sql.format(
+ transaction_table=self.transaction_table_name
+ )
+
+
+def create_versioning_trigger_listeners(manager, cls):
+ sa.event.listen(
+ cls.__table__,
+ 'after_create',
+ sa.schema.DDL(str(CreateTriggerFunctionSQL.for_manager(manager, cls)))
+ )
+ sa.event.listen(
+ cls.__table__,
+ 'after_create',
+ sa.schema.DDL(str(CreateTriggerSQL.for_manager(manager, cls)))
+ )
+ sa.event.listen(
+ cls.__table__,
+ 'after_drop',
+ sa.schema.DDL(
+ 'DROP FUNCTION IF EXISTS %s()' %
+ '%s_audit' % cls.__table__.name,
+ )
+ )
+
+
+def sync_trigger(conn, table_name):
+ """
+ Synchronizes versioning trigger for given table with given connection.
+
+ ::
+
+
+ sync_trigger(conn, 'my_table')
+
+
+ :param conn: SQLAlchemy connection object
+ :param table_name: Name of the table to synchronize versioning trigger for
+
+ .. versionadded: 1.1.0
+ """
+ meta = sa.MetaData()
+ version_table = sa.Table(
+ table_name,
+ meta,
+ autoload=True,
+ autoload_with=conn
+ )
+ parent_table = sa.Table(
+ table_name[0:-len('_version')],
+ meta,
+ autoload=True,
+ autoload_with=conn
+ )
+ excluded_columns = (
+ set(c.name for c in parent_table.c) -
+ set(c.name for c in version_table.c if not c.name.endswith('_mod'))
+ )
+ drop_trigger(conn, parent_table.name)
+ create_trigger(conn, table=parent_table, excluded_columns=excluded_columns)
+
+
+def create_trigger(
+ conn,
+ table,
+ transaction_column_name='transaction_id',
+ operation_type_column_name='operation_type',
+ version_table_name_format='%s_version',
+ excluded_columns=None,
+ use_property_mod_tracking=True,
+ end_transaction_column_name=None,
+):
+ params = dict(
+ table=table,
+ update_validity_for_tables=[],
+ transaction_column_name=transaction_column_name,
+ operation_type_column_name=operation_type_column_name,
+ version_table_name_format=version_table_name_format,
+ excluded_columns=excluded_columns,
+ use_property_mod_tracking=use_property_mod_tracking,
+ end_transaction_column_name=end_transaction_column_name,
+ )
+ conn.execute(str(CreateTriggerFunctionSQL(**params)))
+ conn.execute(str(CreateTriggerSQL(**params)))
+
+
+def drop_trigger(conn, table_name):
+ conn.execute(
+ 'DROP TRIGGER IF EXISTS %s_trigger ON "%s"' % (
+ table_name,
+ table_name
+ )
+ )
+ conn.execute('DROP FUNCTION IF EXISTS %s_audit()' % table_name)
diff --git a/sqlalchemy_continuum/exc.py b/sqlalchemy_continuum/exc.py
new file mode 100644
index 00000000..9bd568ea
--- /dev/null
+++ b/sqlalchemy_continuum/exc.py
@@ -0,0 +1,10 @@
+class VersioningError(Exception):
+ pass
+
+
+class ClassNotVersioned(VersioningError):
+ pass
+
+
+class ImproperlyConfigured(VersioningError):
+ pass
diff --git a/sqlalchemy_continuum/expression_reflector.py b/sqlalchemy_continuum/expression_reflector.py
index 474966c7..fafb14f6 100644
--- a/sqlalchemy_continuum/expression_reflector.py
+++ b/sqlalchemy_continuum/expression_reflector.py
@@ -1,74 +1,33 @@
-import six
import sqlalchemy as sa
-from sqlalchemy.sql.expression import (
- BooleanClauseList,
- BinaryExpression,
- BindParameter
-)
+from sqlalchemy.sql.expression import bindparam
+
from .utils import version_table
-class ExpressionReflector(object):
- parent = None
- parent_class = None
+class VersionExpressionReflector(sa.sql.visitors.ReplacingCloningVisitor):
+ def __init__(self, parent, relationship):
+ self.parent = parent
+ self.relationship = relationship
- def expression(self, expression):
- """
- Parses SQLAlchemy expression
- """
- if expression is None:
+ def replace(self, column):
+ if not isinstance(column, sa.Column):
return
- if isinstance(expression, BinaryExpression):
- return self.binary_expression(expression)
- elif isinstance(expression, BooleanClauseList):
- return self.boolean_expression(expression)
-
- def parameter(self, parameter):
- """
- Parses SQLAlchemy BindParameter
- """
- if isinstance(parameter, sa.Column):
- table = version_table(parameter.table)
- if self.parent and table == self.parent.__table__:
- return getattr(self.parent, parameter.name)
- else:
- return table.c[parameter.name]
- elif isinstance(parameter, BindParameter):
- # somehow bind parameters passed as unicode are converted to
- # ascii strings along the way, force convert them back to avoid
- # sqlalchemy unicode warnings
- if isinstance(parameter.type, sa.Unicode):
- parameter.value = six.text_type(parameter.value)
- return parameter
-
- def binary_expression(self, expression):
- """
- Parses SQLAlchemy BinaryExpression
- """
- return expression.operator(
- self.parameter(expression.left),
- self.parameter(expression.right)
- )
-
- def boolean_expression(self, expression):
- """
- Parses SQLAlchemy BooleanExpression
- """
- return expression.operator(*[
- self.expression(child_expr)
- for child_expr in expression.get_children()
- ])
-
- def __call__(self, expression):
- return self.expression(expression)
-
-
-class ClassExpressionReflector(ExpressionReflector):
- def __init__(self, parent_class):
- self.parent_class = parent_class
-
-
-class ObjectExpressionReflector(ExpressionReflector):
- def __init__(self, parent):
- self.parent = parent
- self.parent_class = parent.__class__
+ try:
+ table = version_table(column.table)
+ except KeyError:
+ reflected_column = column
+ else:
+ reflected_column = table.c[column.name]
+ if (
+ column in self.relationship.local_columns and
+ table == self.parent.__table__
+ ):
+ reflected_column = bindparam(
+ column.key,
+ getattr(self.parent, column.key)
+ )
+
+ return reflected_column
+
+ def __call__(self, expr):
+ return self.traverse(expr)
diff --git a/sqlalchemy_continuum/fetcher.py b/sqlalchemy_continuum/fetcher.py
index 529a4652..1ac1a175 100644
--- a/sqlalchemy_continuum/fetcher.py
+++ b/sqlalchemy_continuum/fetcher.py
@@ -1,21 +1,28 @@
import operator
import sqlalchemy as sa
-from sqlalchemy_utils import primary_keys, identity
+from sqlalchemy_utils import get_primary_keys, identity
from .utils import tx_column_name, end_tx_column_name
-def eq(tuple_):
- return tuple_[0] == tuple_[1]
-
-
def parent_identity(obj_or_class):
return tuple(
- getattr(obj_or_class, column.name)
- for column in primary_keys(obj_or_class)
- if column.name != tx_column_name(obj_or_class)
+ getattr(obj_or_class, column_key)
+ for column_key in get_primary_keys(obj_or_class).keys()
+ if column_key != tx_column_name(obj_or_class)
)
+def eqmap(callback, iterable):
+ for a, b in zip(*map(callback, iterable)):
+ yield a == b
+
+
+def parent_criteria(obj, class_=None):
+ if class_ is None:
+ class_ = obj.__class__
+ return eqmap(parent_identity, (class_, obj))
+
+
class VersionObjectFetcher(object):
def __init__(self, manager):
self.manager = manager
@@ -43,16 +50,7 @@ def next(self, obj):
"""
return self.next_query(obj).first()
- def parent_identity_correlation(self, obj):
- return map(
- eq,
- zip(
- parent_identity(obj.__class__),
- parent_identity(obj)
- )
- )
-
- def _transaction_id_subquery(self, obj, next_or_prev='next'):
+ def _transaction_id_subquery(self, obj, next_or_prev='next', alias=None):
if next_or_prev == 'next':
op = operator.gt
func = sa.func.min
@@ -60,24 +58,37 @@ def _transaction_id_subquery(self, obj, next_or_prev='next'):
op = operator.lt
func = sa.func.max
- alias = sa.orm.aliased(obj)
+ if alias is None:
+ alias = sa.orm.aliased(obj)
+ table = alias.__table__
+ if hasattr(alias, 'c'):
+ attrs = alias.c
+ else:
+ attrs = alias
+ else:
+ table = alias.original
+ attrs = alias.c
query = (
sa.select(
[func(
- getattr(alias, tx_column_name(obj))
+ getattr(attrs, tx_column_name(obj))
)],
- from_obj=[alias.__table__]
+ from_obj=[table]
)
.where(
sa.and_(
op(
- getattr(alias, tx_column_name(obj)),
+ getattr(attrs, tx_column_name(obj)),
getattr(obj, tx_column_name(obj))
),
- *map(eq, zip(parent_identity(alias), parent_identity(obj)))
+ *[
+ getattr(attrs, pk) == getattr(obj, pk)
+ for pk in get_primary_keys(obj.__class__)
+ if pk != tx_column_name(obj)
+ ]
)
)
- .correlate(alias.__table__)
+ .correlate(table)
)
return query
@@ -96,7 +107,7 @@ def _next_prev_query(self, obj, next_or_prev='next'):
self._transaction_id_subquery(
obj, next_or_prev=next_or_prev
),
- *self.parent_identity_correlation(obj)
+ *parent_criteria(obj)
)
)
)
@@ -121,9 +132,7 @@ def _index_query(self, obj):
query = (
sa.select([subquery], from_obj=[obj.__table__])
.where(
- sa.and_(
- *map(eq, zip(identity(obj.__class__), identity(obj)))
- )
+ sa.and_(*eqmap(identity, (obj.__class__, obj)))
)
.order_by(
getattr(obj.__class__, tx_column_name(obj))
@@ -163,7 +172,7 @@ def next_query(self, obj):
getattr(obj.__class__, tx_column_name(obj))
==
getattr(obj, end_tx_column_name(obj)),
- *self.parent_identity_correlation(obj)
+ *parent_criteria(obj)
)
)
)
@@ -182,7 +191,7 @@ def previous_query(self, obj):
getattr(obj.__class__, end_tx_column_name(obj))
==
getattr(obj, tx_column_name(obj)),
- *self.parent_identity_correlation(obj)
+ *parent_criteria(obj)
)
)
)
diff --git a/sqlalchemy_continuum/manager.py b/sqlalchemy_continuum/manager.py
index da6ad6c9..191baca7 100644
--- a/sqlalchemy_continuum/manager.py
+++ b/sqlalchemy_continuum/manager.py
@@ -1,11 +1,9 @@
import re
from functools import wraps
-from inflection import underscore, pluralize
import sqlalchemy as sa
from sqlalchemy.orm import object_session
-from sqlalchemy_utils.functions import is_auto_assigned_date_column
-from sqlalchemy_utils.types import TSVectorType
+from sqlalchemy_utils import get_column_key
from .builder import Builder
from .fetcher import SubqueryFetcher, ValidityFetcher
@@ -13,7 +11,7 @@
from .plugins import PluginCollection
from .transaction import TransactionFactory
from .unit_of_work import UnitOfWork
-from .utils import get_bind, is_modified, is_versioned
+from .utils import is_modified, is_versioned
def tracked_operation(func):
@@ -23,8 +21,20 @@ def wrapper(self, mapper, connection, target):
return
session = object_session(target)
conn = session.connection()
- uow = self.units_of_work[conn]
+ try:
+ uow = self.units_of_work[conn]
+ except KeyError:
+ try:
+ uow = self.units_of_work[conn.engine]
+ except KeyError:
+ for connection in self.units_of_work.keys():
+ if not connection.closed and connection.connection is conn.connection:
+ uow = self.unit_of_work(session)
+ break # The ConnectionFairy is the same, this connection is a clone
+ else:
+ raise
return func(self, uow, target)
+
return wrapper
@@ -33,14 +43,33 @@ class VersioningManager(object):
VersioningManager delegates versioning configuration operations to builder
classes and the actual versioning to UnitOfWork class. Manager contains
configuration options that act as defaults for all versioned classes.
+
+ :param unit_of_work_cls:
+ The UnitOfWork class to use for initializing UnitOfWork objects for
+ versioning
+ :param transaction_cls:
+ Transaction class to use for versioning. If None, the default
+ Transaction class generated by TransactionFactory will be used.
+ :param user_cls:
+ User class which Transaction class should have relationship to. This
+ can either be a class or string name of a class for lazy evaluation.
+ :param options:
+ Versioning options
+ :param plugins:
+ Versioning plugins that listen the events invoked by the manager.
+ :param builder:
+ Builder object which handles the building of versioning tables and
+ models.
"""
+
def __init__(
- self,
- unit_of_work_cls=UnitOfWork,
- transaction_cls=None,
- options={},
- plugins=None,
- builder=None
+ self,
+ unit_of_work_cls=UnitOfWork,
+ transaction_cls=None,
+ user_cls=None,
+ options={},
+ plugins=None,
+ builder=None
):
self.uow_class = unit_of_work_cls
if builder is None:
@@ -51,6 +80,10 @@ def __init__(
self.reset()
if transaction_cls is not None:
self.transaction_cls = transaction_cls
+ else:
+ self.transaction_cls = TransactionFactory()
+ if user_cls is not None:
+ self.user_cls = user_cls
self.options = {
'versioning': True,
@@ -58,11 +91,16 @@ def __init__(
'table_name': '%s_version',
'exclude': [],
'include': [],
+ 'native_versioning': False,
+ 'create_models': True,
+ 'create_tables': True,
+ 'transaction_table_name': 'transaction',
+ 'transaction_table_schema_name': None,
'transaction_column_name': 'transaction_id',
'end_transaction_column_name': 'end_transaction_id',
'operation_type_column_name': 'operation_type',
- 'relation_naming_function': lambda a: pluralize(underscore(a)),
- 'strategy': 'validity'
+ 'strategy': 'validity',
+ 'use_module_name': False
}
if plugins is None:
self.plugins = []
@@ -70,10 +108,6 @@ def __init__(
self.plugins = plugins
self.options.update(options)
- # A dictionary of units of work. Keys as connection objects and values
- # as UnitOfWork objects.
- self.units_of_work = {}
-
@property
def plugins(self):
return self._plugins
@@ -98,12 +132,32 @@ def reset(self):
"""
self.tables = {}
self.pending_classes = []
- self.association_tables = set([])
- self.association_version_tables = set([])
+ self.association_tables = set()
+ self.association_version_tables = set()
self.declarative_base = None
- self.transaction_cls = TransactionFactory()
self.version_class_map = {}
self.parent_class_map = {}
+ self.session_listeners = {
+ 'before_flush': self.before_flush,
+ 'after_flush': self.after_flush,
+ 'after_commit': self.clear,
+ 'after_rollback': self.clear,
+ }
+ self.mapper_listeners = {
+ 'after_delete': self.track_deletes,
+ 'after_update': self.track_updates,
+ 'after_insert': self.track_inserts,
+ }
+ self.class_config_listeners = {
+ 'instrument_class': self.builder.instrument_versioned_classes,
+ 'after_configured': self.builder.configure_versioned_classes,
+ }
+
+ # A dictionary of units of work. Keys as connection objects and values
+ # as UnitOfWork objects.
+ self.units_of_work = {}
+
+ self.session_connection_map = {}
self.metadata = None
@@ -117,22 +171,24 @@ def create_transaction_model(self):
return self.transaction_cls
def is_excluded_column(self, model, column):
+ try:
+ key = get_column_key(model, column)
+ except sa.orm.exc.UnmappedColumnError:
+ return False
+
+ return self.is_excluded_property(model, key)
+
+ def is_excluded_property(self, model, key):
"""
- Returns whether or not given column of given model is excluded from
+ Returns whether or not given property of given model is excluded from
the associated history model.
:param model: SQLAlchemy declarative model object.
- :param column: SQLAlchemy Column object.
+ :param key: Model property key
"""
- if column.name in self.option(model, 'include'):
+ if key in self.option(model, 'include'):
return False
- return (
- column.name in self.option(model, 'exclude')
- or
- is_auto_assigned_date_column(column)
- or
- isinstance(column.type, TSVectorType)
- )
+ return key in self.option(model, 'exclude')
def option(self, model, name):
"""
@@ -169,16 +225,18 @@ def apply_class_configuration_listeners(self, mapper):
:param mapper:
SQLAlchemy mapper to apply the class configuration listeners to
"""
- sa.event.listen(
- mapper,
- 'instrument_class',
- self.builder.instrument_versioned_classes
- )
- sa.event.listen(
- mapper,
- 'after_configured',
- self.builder.configure_versioned_classes
- )
+ for event_name, listener in self.class_config_listeners.items():
+ sa.event.listen(mapper, event_name, listener)
+
+ def remove_class_configuration_listeners(self, mapper):
+ """
+ Remove versioning class configuration listeners from specified mapper.
+
+ :param mapper:
+ mapper to remove class configuration listeners from
+ """
+ for event_name, listener in self.class_config_listeners.items():
+ sa.event.remove(mapper, event_name, listener)
def track_operations(self, mapper):
"""
@@ -187,36 +245,42 @@ def track_operations(self, mapper):
:param mapper: mapper to track the SQL operations from
"""
- sa.event.listen(
- mapper, 'after_delete', self.track_deletes
- )
- sa.event.listen(
- mapper, 'after_update', self.track_updates
- )
- sa.event.listen(
- mapper, 'after_insert', self.track_inserts
- )
+ for event_name, listener in self.mapper_listeners.items():
+ sa.event.listen(mapper, event_name, listener)
+
+ def remove_operations_tracking(self, mapper):
+ """
+ Remove listeners from specified mapper that track SQL inserts, updates
+ and deletes.
+
+ :param mapper:
+ mapper to remove the SQL operations tracking listeners from
+ """
+ for event_name, listener in self.mapper_listeners.items():
+ sa.event.remove(mapper, event_name, listener)
def track_session(self, session):
"""
Attach listeners that track the operations (flushing, committing and
rolling back) of given session. This method should be used in
- conjuction with `track_operations`.
+ conjunction with `track_operations`.
:param session: SQLAlchemy session to track the operations from
"""
- sa.event.listen(
- session, 'before_flush', self.before_flush
- )
- sa.event.listen(
- session, 'after_flush', self.after_flush
- )
- sa.event.listen(
- session, 'after_commit', self.clear
- )
- sa.event.listen(
- session, 'after_rollback', self.clear
- )
+ for event_name, listener in self.session_listeners.items():
+ sa.event.listen(session, event_name, listener)
+
+ def remove_session_tracking(self, session):
+ """
+ Remove listeners that track the operations (flushing, committing and
+ rolling back) of given session. This method should be used in
+ conjunction with `remove_operations_tracking`.
+
+ :param session:
+ SQLAlchemy session to remove the operations tracking from
+ """
+ for event_name, listener in self.session_listeners.items():
+ sa.event.remove(session, event_name, listener)
@tracked_operation
def track_inserts(self, uow, target):
@@ -224,8 +288,6 @@ def track_inserts(self, uow, target):
Track object insert operations. Whenever object is inserted it is
added to this UnitOfWork's internal operations dictionary.
"""
- if not is_modified(target):
- return
uow.operations.add_insert(target)
@tracked_operation
@@ -246,19 +308,19 @@ def track_deletes(self, uow, target):
"""
uow.operations.add_delete(target)
- def unit_of_work(self, obj):
+ def unit_of_work(self, session):
"""
Return the associated SQLAlchemy-Continuum UnitOfWork object for given
- SQLAlchemy connection or session or declarative model object.
+ SQLAlchemy session object.
If no UnitOfWork object exists for given object then this method tries
to create one.
- :param obj:
- Either a SQLAlchemy declarative model object or SQLAlchemy
- connection object or SQLAlchemy session object
+ :param session: SQLAlchemy session object
"""
- conn = get_bind(obj)
+ conn = session.connection()
+ if conn not in self.session_connection_map.values():
+ self.session_connection_map[session] = conn
if conn in self.units_of_work:
return self.units_of_work[conn]
@@ -268,6 +330,13 @@ def unit_of_work(self, obj):
return uow
def before_flush(self, session, flush_context, instances):
+ """
+ Before flush listener for SQLAlchemy sessions. If this manager has
+ versioning enabled this listener invokes the process before flush of
+ associated UnitOfWork object.
+
+ :param session: SQLAlchemy session
+ """
if not self.options['versioning']:
return
@@ -290,16 +359,45 @@ def after_flush(self, session, flush_context):
def clear(self, session):
"""
- Simple SQLAlchemy listener that is being invoked after succesful
+ Simple SQLAlchemy listener that is being invoked after successful
transaction commit or when transaction rollback occurs. The purpose of
this listener is to reset this UnitOfWork back to its initialization
state.
:param session: SQLAlchemy session object
"""
- conn = session.bind
- uow = self.units_of_work[conn]
- uow.reset()
+ if session.transaction.nested:
+ return
+ conn = self.session_connection_map.pop(session, None)
+ if conn is None:
+ return
+
+ if conn in self.units_of_work:
+ uow = self.units_of_work[conn]
+ uow.reset(session)
+ del self.units_of_work[conn]
+
+ for connection in dict(self.units_of_work).keys():
+ if connection.closed or conn.connection is connection.connection:
+ uow = self.units_of_work[connection]
+ uow.reset(session)
+ del self.units_of_work[connection]
+
+ def clear_connection(self, conn):
+ if conn in self.units_of_work:
+ uow = self.units_of_work[conn]
+ uow.reset()
+ del self.units_of_work[conn]
+
+ for session, connection in dict(self.session_connection_map).items():
+ if connection is conn:
+ del self.session_connection_map[session]
+
+ for connection in dict(self.units_of_work).keys():
+ if connection.closed or conn.connection is connection.connection:
+ uow = self.units_of_work[connection]
+ uow.reset()
+ del self.units_of_work[connection]
def append_association_operation(self, conn, table_name, params, op):
"""
@@ -308,24 +406,46 @@ def append_association_operation(self, conn, table_name, params, op):
params['operation_type'] = op
stmt = (
self.metadata.tables[self.options['table_name'] % table_name]
- .insert()
- .values(params)
+ .insert()
+ .values(params)
)
- uow = self.units_of_work[conn]
+ try:
+ uow = self.units_of_work[conn]
+ except KeyError:
+ try:
+ uow = self.units_of_work[conn.engine]
+ except KeyError:
+ for connection in self.units_of_work.keys():
+ if not connection.closed and connection.connection is conn.connection:
+ uow = self.unit_of_work(conn.session)
+ break # The ConnectionFairy is the same, this connection is a clone
+ else:
+ raise
uow.pending_statements.append(stmt)
+ def track_cloned_connections(self, c, opt):
+ """
+ Track cloned connections from association tables.
+ """
+ if c not in self.units_of_work.keys():
+ for connection, uow in dict(self.units_of_work).items():
+ if not connection.closed and connection.connection is c.connection: # ConnectionFairy is the same - this is a clone
+ self.units_of_work[c] = uow
+
def track_association_operations(
- self, conn, cursor, statement, parameters, context, executemany
+ self, conn, cursor, statement, parameters, context, executemany
):
"""
Track association operations and adds the generated history
association operations to pending_statements list.
"""
- if not self.options['versioning']:
+ if (
+ not self.options['versioning'] and
+ not self.options['native_versioning']
+ ):
return
op = None
-
if context.isinsert:
op = Operation.INSERT
elif context.isdelete:
@@ -334,7 +454,8 @@ def track_association_operations(
if op is not None:
table_name = statement.split(' ')[2]
table_names = [
- table.name for table in self.association_tables
+ table.name if not table.schema else table.schema + '.' + table.name
+ for table in self.association_tables
]
if table_name in table_names:
if executemany:
diff --git a/sqlalchemy_continuum/model_builder.py b/sqlalchemy_continuum/model_builder.py
index 3a9609d7..2be6e63b 100644
--- a/sqlalchemy_continuum/model_builder.py
+++ b/sqlalchemy_continuum/model_builder.py
@@ -1,11 +1,88 @@
from copy import copy
+import six
import sqlalchemy as sa
-from sqlalchemy_utils.functions import primary_keys, declarative_base
-from .expression_reflector import ClassExpressionReflector
-from .utils import option
+from sqlalchemy.ext.declarative import declared_attr
+from sqlalchemy.orm import column_property
+from sqlalchemy_utils.functions import get_declarative_base
+
+from .utils import adapt_columns, option
from .version import VersionClassBase
+def find_closest_versioned_parent(manager, model):
+ """
+ Finds the closest versioned parent for current parent model.
+ """
+ for class_ in model.__bases__:
+ if class_ in manager.version_class_map:
+ return manager.version_class_map[class_]
+
+def versioned_parents(manager, model):
+ """
+ Finds all versioned ancestors for current parent model.
+ """
+ for class_ in model.__mro__:
+ if class_ in manager.version_class_map:
+ yield manager.version_class_map[class_]
+
+
+def get_base_class(manager, model):
+ """
+ Returns all base classes for history model.
+ """
+ return (
+ option(model, 'base_classes')
+ or
+ (get_declarative_base(model), )
+ )
+
+
+def version_base(manager, parent_cls, base_class_factory=None):
+ if base_class_factory is None:
+ base_class_factory = get_base_class
+
+ VersionBase = find_closest_versioned_parent(manager, parent_cls)
+
+ if not VersionBase:
+ VersionBase = type(
+ 'VersionBase',
+ (base_class_factory(manager, parent_cls) + (VersionClassBase, )),
+ {'__abstract__': True}
+ )
+
+ return VersionBase
+
+
+def copy_mapper_args(model):
+ args = {}
+ if hasattr(model, '__mapper_args__'):
+ arg_names = (
+ 'with_polymorphic',
+ 'polymorphic_identity',
+ 'concrete'
+ )
+ for arg in arg_names:
+ if arg in model.__mapper_args__:
+ args[arg] = (
+ model.__mapper_args__[arg]
+ )
+
+ if 'order_by' in model.__mapper_args__:
+ arg = model.__mapper_args__['order_by']
+ # Only allow string based order_by reflection to version
+ # classes.
+ if isinstance(arg, six.string_types):
+ args['order_by'] = arg
+
+ if 'polymorphic_on' in model.__mapper_args__:
+ column = model.__mapper_args__['polymorphic_on']
+ if isinstance(column, six.string_types):
+ args['polymorphic_on'] = column
+ else:
+ args['polymorphic_on'] = column.key
+ return args
+
+
class ModelBuilder(object):
"""
VersionedModelBuilder handles the building of Version models based on
@@ -30,15 +107,16 @@ class represents).
"""
conditions = []
foreign_keys = []
- for primary_key in primary_keys(self.model):
- conditions.append(
- getattr(self.model, primary_key.name)
- ==
- getattr(self.version_class, primary_key.name)
- )
- foreign_keys.append(
- getattr(self.version_class, primary_key.name)
- )
+ for key, column in sa.inspect(self.model).columns.items():
+ if column.primary_key:
+ conditions.append(
+ getattr(self.model, key)
+ ==
+ getattr(self.version_class, key)
+ )
+ foreign_keys.append(
+ getattr(self.version_class, key)
+ )
# We need to check if versions relation was already set for parent
# class.
@@ -47,6 +125,10 @@ class represents).
self.version_class,
primaryjoin=sa.and_(*conditions),
foreign_keys=foreign_keys,
+ order_by=lambda: getattr(
+ self.version_class,
+ option(self.model, 'transaction_column_name')
+ ),
lazy='dynamic',
backref=sa.orm.backref(
'version_parent'
@@ -54,20 +136,16 @@ class represents).
viewonly=True
)
- def build_transaction_relationship(self, tx_log_class):
+ def build_transaction_relationship(self, tx_class):
"""
Builds a relationship between currently built version class and
Transaction class.
- :param tx_log_class: Transaction class
+ :param tx_class: Transaction class
"""
# Only define transaction relation if it doesn't already exist in
# parent class.
- backref_name = option(self.model, 'relation_naming_function')(
- self.model.__name__
- )
-
transaction_column = getattr(
self.version_class,
option(self.model, 'transaction_column_name')
@@ -75,72 +153,115 @@ def build_transaction_relationship(self, tx_log_class):
if not hasattr(self.version_class, 'transaction'):
self.version_class.transaction = sa.orm.relationship(
- tx_log_class,
- primaryjoin=tx_log_class.id == transaction_column,
+ tx_class,
+ primaryjoin=tx_class.id == transaction_column,
foreign_keys=[transaction_column],
- backref=backref_name
- )
- else:
- setattr(
- tx_log_class,
- backref_name,
- sa.orm.relationship(
- self.version_class,
- primaryjoin=tx_log_class.id == transaction_column,
- foreign_keys=[transaction_column]
- )
)
- def find_closest_versioned_parent(self):
- """
- Finds the closest versioned parent for current parent model.
- """
- for class_ in self.model.__bases__:
- if class_ in self.manager.version_class_map:
- return (self.manager.version_class_map[class_], )
-
def base_classes(self):
"""
Returns all base classes for history model.
"""
- parents = (
- self.find_closest_versioned_parent()
- or option(self.model, 'base_classes')
- or (declarative_base(self.model), )
- )
- return parents + (VersionClassBase, )
+ return (version_base(self.manager, self.model), )
- def inheritance_args(self):
+ def inheritance_args(self, cls, version_table, table):
"""
Return mapper inheritance args for currently built history model.
"""
- if self.find_closest_versioned_parent():
- reflector = ClassExpressionReflector(self.model)
- mapper = sa.inspect(self.model)
- inherit_condition = reflector(mapper.inherit_condition)
+ args = {}
+
+ if not sa.inspect(self.model).single:
+ parent = find_closest_versioned_parent(
+ self.manager, self.model
+ )
+ if parent:
+ # The version classes do not contain foreign keys, hence we
+ # need to map inheritance condition manually for classes that
+ # use joined table inheritance
+ if parent.__table__.name != table.name:
+ mapper = sa.inspect(self.model)
+
+ inherit_condition = adapt_columns(
+ mapper.inherit_condition
+ )
+ tx_column_name = self.manager.options[
+ 'transaction_column_name'
+ ]
+ args['inherit_condition'] = sa.and_(
+ inherit_condition,
+ getattr(parent.__table__.c, tx_column_name) ==
+ getattr(cls.__table__.c, tx_column_name)
+ )
+ args['inherit_foreign_keys'] = [
+ version_table.c[column.key]
+ for column in sa.inspect(self.model).columns
+ if column.primary_key
+ ]
+
+ args.update(copy_mapper_args(self.model))
+
+ return args
- return {
- 'inherit_condition': inherit_condition
- }
- return {}
+ def get_inherited_denormalized_columns(self, table):
+ parent_models = list(versioned_parents(self.manager, self.model))
+ mapper = sa.inspect(self.model)
+ args = {}
+
+ if parent_models and not (mapper.single or mapper.concrete):
+ columns = [
+ self.manager.option(self.model, 'operation_type_column_name'),
+ self.manager.option(self.model, 'transaction_column_name')
+ ]
+ if self.manager.option(self.model, 'strategy') == 'validity':
+ columns.append(
+ self.manager.option(
+ self.model,
+ 'end_transaction_column_name'
+ )
+ )
+
+ for column in columns:
+ args[column] = column_property(
+ table.c[column],
+ *[m.__table__.c[column] for m in parent_models]
+ )
+ return args
def build_model(self, table):
"""
Build history model class.
"""
- mapper_args = {}
- mapper_args.update(self.inheritance_args())
-
- return type(
- '%sVersion' % self.model.__name__,
- self.base_classes(),
- {
- '__table__': table,
- '__mapper_args__': mapper_args
- }
- )
+ args = {}
+
+ @declared_attr
+ def mapper_args(cls):
+ mapper_args = {}
+ mapper_args.update(self.inheritance_args(
+ cls, table, self.model.__table__)
+ )
+ return mapper_args
+
+ args['__mapper_args__'] = mapper_args
+ args['__versioning_manager__'] = self.manager
+ args['__version_parent__'] = self.model
+
+ parent = find_closest_versioned_parent(self.manager, self.model)
+
+ if not parent or parent.__table__.name != table.name:
+ args['__table__'] = table
+
+ args.update(self.get_inherited_denormalized_columns(table))
+
+ if self.manager.options.get('use_module_name', True):
+ name = '%s%sVersion' % (
+ self.model.__module__.title().replace('.', ''),
+ self.model.__name__
+ )
+ else:
+ name = '%sVersion' % (self.model.__name__,)
+ return type(name, self.base_classes(), args)
- def __call__(self, table, tx_log_class):
+ def __call__(self, table, tx_class):
"""
Build history model and relationships to parent model, transaction
log model.
@@ -152,8 +273,5 @@ def __call__(self, table, tx_log_class):
self.model.__versioning_manager__ = self.manager
self.version_class = self.build_model(table)
self.build_parent_relationship()
- self.build_transaction_relationship(tx_log_class)
- self.version_class.__versioning_manager__ = self.manager
- self.manager.version_class_map[self.model] = self.version_class
- self.manager.parent_class_map[self.version_class] = self.model
+ self.build_transaction_relationship(tx_class)
return self.version_class
diff --git a/sqlalchemy_continuum/operation.py b/sqlalchemy_continuum/operation.py
index 6aea5685..72c72b2d 100644
--- a/sqlalchemy_continuum/operation.py
+++ b/sqlalchemy_continuum/operation.py
@@ -59,6 +59,9 @@ def __bool__(self):
def __nonzero__(self):
return self.__bool__()
+ def __repr__(self):
+ return repr(self.objects)
+
@property
def entities(self):
"""
diff --git a/sqlalchemy_continuum/plugins/activity.py b/sqlalchemy_continuum/plugins/activity.py
index 09c8f9db..8905079a 100644
--- a/sqlalchemy_continuum/plugins/activity.py
+++ b/sqlalchemy_continuum/plugins/activity.py
@@ -67,6 +67,10 @@
session.commit()
+Targets and objects of given activity must have an integer primary key
+column id.
+
+
Create activities
^^^^^^^^^^^^^^^^^
@@ -84,7 +88,7 @@
article = Article(name=u'Some article')
session.add(article)
- session.commit()
+ session.flush()
first_activity = Activity(verb=u'create', object=article)
session.add(first_activity)
session.commit()
@@ -161,7 +165,7 @@
Now if we wanted to find all the changes that affected given article we could
-do so by searching trhough all the activities where either the object or
+do so by searching through all the activities where either the object or
target is the given article.
@@ -182,12 +186,12 @@
.. _activity stream specification:
http://www.activitystrea.ms
.. _generic relationships:
- http://sqlalchemy-utils.readthedocs.org/en/latest/generic_relationship.html
+ https://sqlalchemy-utils.readthedocs.io/en/latest/generic_relationship.html
"""
import sqlalchemy as sa
from sqlalchemy.ext.hybrid import hybrid_property
-from sqlalchemy_utils import JSONType, generates, generic_relationship
+from sqlalchemy_utils import JSONType, generic_relationship
from .base import Plugin
from ..factory import ModelFactory
@@ -195,13 +199,18 @@
class ActivityBase(object):
- id = sa.Column(sa.BigInteger, primary_key=True, autoincrement=True)
+ id = sa.Column(
+ sa.BigInteger,
+ sa.schema.Sequence('activity_id_seq'),
+ primary_key=True,
+ autoincrement=True
+ )
verb = sa.Column(sa.Unicode(255))
@hybrid_property
def actor(self):
- self.transaction.user
+ return self.transaction.user
class ActivityFactory(ModelFactory):
@@ -238,36 +247,25 @@ class Activity(
target_tx_id = sa.Column(sa.BigInteger)
- @generates(object_tx_id)
- def generate_object_tx_id(self):
+ def _calculate_tx_id(self, obj):
session = sa.orm.object_session(self)
- if self.object:
- object_version = version_obj(session, self.object)
+ if obj:
+ object_version = version_obj(session, obj)
if object_version:
return object_version.transaction_id
- version_cls = version_class(self.object.__class__)
+ version_cls = version_class(obj.__class__)
return session.query(
sa.func.max(version_cls.transaction_id)
).filter(
- version_cls.id == self.object_id
+ version_cls.id == obj.id
).scalar()
- @generates(target_tx_id)
- def generate_target_tx_id(self):
- session = sa.orm.object_session(self)
- if self.target:
- target_version = version_obj(session, self.target)
+ def calculate_object_tx_id(self):
+ self.object_tx_id = self._calculate_tx_id(self.object)
- if target_version:
- return target_version.transaction_id
-
- version_cls = version_class(self.target.__class__)
- return session.query(
- sa.func.max(version_cls.transaction_id)
- ).filter(
- version_cls.id == self.target_id
- ).scalar()
+ def calculate_target_tx_id(self):
+ self.target_tx_id = self._calculate_tx_id(self.target)
object = generic_relationship(
object_type, object_id
@@ -333,6 +331,8 @@ def before_flush(self, uow, session):
for obj in session:
if isinstance(obj, self.activity_cls):
obj.transaction = uow.current_transaction
+ obj.calculate_target_tx_id()
+ obj.calculate_object_tx_id()
def after_version_class_built(self, parent_cls, version_cls):
pass
diff --git a/sqlalchemy_continuum/plugins/base.py b/sqlalchemy_continuum/plugins/base.py
index c27c6fe5..a15db71e 100644
--- a/sqlalchemy_continuum/plugins/base.py
+++ b/sqlalchemy_continuum/plugins/base.py
@@ -23,11 +23,8 @@ def after_create_version_objects(self, uow, session):
def after_create_version_object(self, uow, parent_obj, version_obj):
pass
- def before_create_tx_object(self, uow, session):
- pass
-
- def after_create_tx_object(self, uow, session):
- pass
+ def transaction_args(self, uow, session):
+ return {}
def after_version_class_built(self, parent_cls, version_cls):
pass
diff --git a/sqlalchemy_continuum/plugins/flask.py b/sqlalchemy_continuum/plugins/flask.py
index 1f56a454..c7b14254 100644
--- a/sqlalchemy_continuum/plugins/flask.py
+++ b/sqlalchemy_continuum/plugins/flask.py
@@ -23,7 +23,6 @@
import flask
from flask import request
from flask.globals import _app_ctx_stack, _request_ctx_stack
- from flask.ext.login import current_user
except ImportError:
pass
from sqlalchemy_utils import ImproperlyConfigured
@@ -31,6 +30,8 @@
def fetch_current_user_id():
+ from flask_login import current_user
+
# Return None if we are outside of request context.
if _app_ctx_stack.top is None or _request_ctx_stack.top is None:
return
@@ -48,13 +49,28 @@ def fetch_remote_addr():
class FlaskPlugin(Plugin):
- def __init__(self):
+ def __init__(
+ self,
+ current_user_id_factory=None,
+ remote_addr_factory=None
+ ):
+ self.current_user_id_factory = (
+ fetch_current_user_id if current_user_id_factory is None
+ else current_user_id_factory
+ )
+ self.remote_addr_factory = (
+ fetch_remote_addr if remote_addr_factory is None
+ else remote_addr_factory
+ )
+
if not flask:
raise ImproperlyConfigured(
'Flask is required with FlaskPlugin. Please install Flask by'
' running pip install Flask'
)
- def before_create_tx_object(self, uow, session):
- uow.current_transaction.user_id = fetch_current_user_id()
- uow.current_transaction.remote_addr = fetch_remote_addr()
+ def transaction_args(self, uow, session):
+ return {
+ 'user_id': self.current_user_id_factory(),
+ 'remote_addr': self.remote_addr_factory()
+ }
diff --git a/sqlalchemy_continuum/plugins/property_mod_tracker.py b/sqlalchemy_continuum/plugins/property_mod_tracker.py
index 816e5498..498eb07f 100644
--- a/sqlalchemy_continuum/plugins/property_mod_tracker.py
+++ b/sqlalchemy_continuum/plugins/property_mod_tracker.py
@@ -24,25 +24,33 @@
class PropertyModTrackerPlugin(Plugin):
column_suffix = '_mod'
+ def create_mod_column(self, column):
+ return sa.Column(
+ column.name + self.column_suffix,
+ sa.Boolean,
+ key=column.key + self.column_suffix,
+ default=False,
+ server_default=sa.sql.expression.false(),
+ nullable=False
+ )
+
def after_build_version_table_columns(self, table_builder, columns):
- for column in table_builder.parent_table.c:
- if not table_builder.manager.is_excluded_column(
- table_builder.model, column
- ) and not column.primary_key:
- columns.append(
- sa.Column(
- column.name + self.column_suffix,
- sa.Boolean,
- key=column.key + self.column_suffix,
- default=False,
- server_default=sa.sql.expression.false(),
- nullable=False
- )
- )
+ # Only create modification tracking columns for tables that are
+ # associated with actual model classes. In other words do not create
+ # mod tracking columns for association tables.
+ if table_builder.model:
+ for column in table_builder.parent_table.c:
+ if not table_builder.manager.is_excluded_column(
+ table_builder.model, column
+ ) and not column.primary_key:
+ columns.append(self.create_mod_column(column))
def after_create_version_object(self, uow, parent_obj, version_obj):
+ session = sa.orm.object_session(parent_obj)
+ is_deleted = parent_obj in session.deleted
+
for prop in versioned_column_properties(parent_obj):
- if has_changes(parent_obj, prop.key):
+ if has_changes(parent_obj, prop.key) or is_deleted:
setattr(
version_obj,
prop.key + self.column_suffix,
diff --git a/sqlalchemy_continuum/plugins/transaction_changes.py b/sqlalchemy_continuum/plugins/transaction_changes.py
index 10b24209..f4fc9cb5 100644
--- a/sqlalchemy_continuum/plugins/transaction_changes.py
+++ b/sqlalchemy_continuum/plugins/transaction_changes.py
@@ -94,22 +94,4 @@ def ater_commit(self, uow, session):
self.clear()
def after_version_class_built(self, parent_cls, version_cls):
- transaction_column = getattr(
- version_cls,
- option(parent_cls, 'transaction_column_name')
- )
-
- # Only define changes relation if it doesn't already exist in
- # parent class.
- if not hasattr(version_cls, 'changes'):
- version_cls.changes = sa.orm.relationship(
- self.model_class,
- primaryjoin=(
- self.model_class.transaction_id == transaction_column
- ),
- foreign_keys=[self.model_class.transaction_id],
- backref=option(parent_cls, 'relation_naming_function')(
- parent_cls.__name__
- )
- )
parent_cls.__versioned__['transaction_changes'] = self.model_class
diff --git a/sqlalchemy_continuum/plugins/transaction_meta.py b/sqlalchemy_continuum/plugins/transaction_meta.py
index 30e7d260..773f9adb 100644
--- a/sqlalchemy_continuum/plugins/transaction_meta.py
+++ b/sqlalchemy_continuum/plugins/transaction_meta.py
@@ -26,8 +26,8 @@
article = Article()
session.add(article)
- uow = unit_of_work(session)
- tx = uow.create_transaction()
+ uow = versioning_manager.unit_of_work(session)
+ tx = uow.create_transaction(session)
tx.meta = {u'some_key': u'some value'}
session.commit()
diff --git a/sqlalchemy_continuum/relationship_builder.py b/sqlalchemy_continuum/relationship_builder.py
index ed8d9e9c..f3dc368d 100644
--- a/sqlalchemy_continuum/relationship_builder.py
+++ b/sqlalchemy_continuum/relationship_builder.py
@@ -1,8 +1,10 @@
import sqlalchemy as sa
-from .table_builder import TableBuilder
-from .expression_reflector import ObjectExpressionReflector
+
+from .exc import ClassNotVersioned
+from .expression_reflector import VersionExpressionReflector
from .operation import Operation
-from .utils import version_table, version_class, option
+from .table_builder import TableBuilder
+from .utils import adapt_columns, version_class, option
class RelationshipBuilder(object):
@@ -12,30 +14,41 @@ def __init__(self, versioning_manager, model, property_):
self.model = model
def one_to_many_subquery(self, obj):
- primary_keys = []
-
tx_column = option(obj, 'transaction_column_name')
- for column in self.remote_cls.__table__.c:
- if column.primary_key and column.name != tx_column:
- primary_keys.append(column)
+ remote_alias = sa.orm.aliased(self.remote_cls)
+ primary_keys = [
+ getattr(remote_alias, column.name) for column
+ in sa.inspect(remote_alias).mapper.columns
+ if column.primary_key and column.name != tx_column
+ ]
- return getattr(self.remote_cls, tx_column).in_(
+ return sa.exists(
sa.select(
- [sa.func.max(getattr(self.remote_cls, tx_column))]
+ [1]
).where(
- getattr(self.remote_cls, tx_column) <=
- getattr(obj, tx_column)
+ sa.and_(
+ getattr(remote_alias, tx_column) <=
+ getattr(obj, tx_column),
+ *[
+ getattr(remote_alias, pk.name) ==
+ getattr(self.remote_cls, pk.name)
+ for pk in primary_keys
+ ]
+ )
).group_by(
*primary_keys
- ).correlate(self.local_cls)
+ ).having(
+ sa.func.max(getattr(remote_alias, tx_column)) ==
+ getattr(self.remote_cls, tx_column)
+ ).correlate(self.local_cls, self.remote_cls)
)
def many_to_one_subquery(self, obj):
tx_column = option(obj, 'transaction_column_name')
- reflector = ObjectExpressionReflector(obj)
+ reflector = VersionExpressionReflector(obj, self.property)
- return getattr(self.remote_cls, tx_column).in_(
+ return getattr(self.remote_cls, tx_column) == (
sa.select(
[sa.func.max(getattr(self.remote_cls, tx_column))]
).where(
@@ -44,7 +57,7 @@ def many_to_one_subquery(self, obj):
getattr(obj, tx_column),
reflector(self.property.primaryjoin)
)
- ).correlate(self.local_cls)
+ )
)
def query(self, obj):
@@ -71,32 +84,94 @@ def process_query(self, query):
def criteria(self, obj):
direction = self.property.direction
- if direction.name == 'ONETOMANY':
- return self.one_to_many_criteria(obj)
- elif direction.name == 'MANYTOMANY':
- return self.many_to_many_criteria(obj)
- elif direction.name == 'MANYTOONE':
- return self.many_to_one_criteria(obj)
+
+ if self.versioned:
+ if direction.name == 'ONETOMANY':
+ return self.one_to_many_criteria(obj)
+ elif direction.name == 'MANYTOMANY':
+ return self.many_to_many_criteria(obj)
+ elif direction.name == 'MANYTOONE':
+ return self.many_to_one_criteria(obj)
+ else:
+ reflector = VersionExpressionReflector(obj, self.property)
+ return reflector(self.property.primaryjoin)
def many_to_many_criteria(self, obj):
- tx_column = option(obj, 'transaction_column_name')
- condition = (
- getattr(self.remote_cls, tx_column) == sa.select(
- [sa.func.max(getattr(self.remote_cls, tx_column))]
- ).where(
- sa.and_(
- getattr(self.remote_cls, tx_column) <=
- getattr(obj, tx_column),
- )
- ).correlate(self.local_cls)
+ """
+ Returns the many-to-many query.
+
+ Looks up remote items through associations and for each item returns
+ returns the last version with a transaction less than or equal to the
+ transaction of `obj`. This must hold true for both the association and
+ the remote relation items.
+
+ Example
+ -------
+ Select all tags of article with id 3 and transaction 5
+
+ .. code-block:: sql
+
+ SELECT tags_version.*
+ FROM tags_version
+ WHERE EXISTS (
+ SELECT 1
+ FROM article_tag_version
+ WHERE article_id = 3
+ AND tag_id = tags_version.id
+ AND operation_type != 2
+ AND EXISTS (
+ SELECT 1
+ FROM article_tag_version as article_tag_version2
+ WHERE article_tag_version2.tag_id = article_tag_version.tag_id
+ AND article_tag_version2.tx_id <= 5
+ GROUP BY article_tag_version2.tag_id
+ HAVING
+ MAX(article_tag_version2.tx_id) =
+ article_tag_version.tx_id
+ )
+ )
+ AND EXISTS (
+ SELECT 1
+ FROM tags_version as tags_version_2
+ WHERE tags_version_2.id = tags_version.id
+ AND tags_version_2.tx_id <= 5
+ GROUP BY tags_version_2.id
+ HAVING MAX(tags_version_2.tx_id) = tags_version.tx_id
)
+ AND operation_type != 2
+ """
return sa.and_(
- self.remote_cls.id.in_(self.association_subquery(obj)),
- condition
+ self.association_subquery(obj),
+ self.one_to_many_subquery(obj),
+ self.remote_cls.operation_type != Operation.DELETE
)
def many_to_one_criteria(self, obj):
- reflector = ObjectExpressionReflector(obj)
+ """Returns the many-to-one query.
+
+ Returns the item on the 'one' side with the highest transaction id
+ as long as it is less or equal to the transaction id of the `obj`.
+
+ Example
+ -------
+ Look up the Article of a Tag with article_id = 4 and
+ transaction_id = 5
+
+ .. code-block:: sql
+
+ SELECT *
+ FROM articles_version
+ WHERE id = 4
+ AND transaction_id = (
+ SELECT max(transaction_id)
+ FROM articles_version
+ WHERE transaction_id <= 5
+ AND id = 4
+ )
+ AND operation_type != 2
+
+ """
+ reflector = VersionExpressionReflector(obj, self.property)
return sa.and_(
reflector(self.property.primaryjoin),
self.many_to_one_subquery(obj),
@@ -104,7 +179,37 @@ def many_to_one_criteria(self, obj):
)
def one_to_many_criteria(self, obj):
- reflector = ObjectExpressionReflector(obj)
+ """
+ Returns the one-to-many query.
+
+ For each item on the 'many' side, returns its latest version as long as
+ the transaction of that version is less than equal of the transaction
+ of `obj`.
+
+ Example
+ -------
+ Using the Article-Tags relationship, where we look for tags of
+ article_version with id = 3 and transaction = 5 the sql produced is
+
+ .. code-block:: sql
+
+ SELECT tags_version.*
+ FROM tags_version
+ WHERE tags_version.article_id = 3
+ AND tags_version.operation_type != 2
+ AND EXISTS (
+ SELECT 1
+ FROM tags_version as tags_version_last
+ WHERE tags_version_last.transaction_id <= 5
+ AND tags_version_last.id = tags_version.id
+ GROUP BY tags_version_last.id
+ HAVING
+ MAX(tags_version_last.transaction_id) =
+ tags_version.transaction_id
+ )
+
+ """
+ reflector = VersionExpressionReflector(obj, self.property)
return sa.and_(
reflector(self.property.primaryjoin),
self.one_to_many_subquery(obj),
@@ -125,56 +230,76 @@ def relationship(obj):
def association_subquery(self, obj):
"""
- Returns association subquery for given SQLAlchemy declarative object.
- This query is used by many_to_many_criteria method.
+ Returns an EXISTS clause that checks if an association exists for given
+ SQLAlchemy declarative object. This query is used by
+ many_to_many_criteria method.
Example query:
- SELECT article_tag_version.tag_id
- FROM article_tag_version
- WHERE
- article_tag_version.transaction_id IN (
- SELECT max(article_tag_version.transaction_id) AS max_1
- FROM article_tag_version
- WHERE
- article_tag_version.transaction_id <= ? AND
- article_tag_version.article_id = ?
- GROUP BY article_tag_version.tag_id
- ) AND
- article_tag_version.article_id = ? AND
- article_tag_version.operation_type != ?
+ .. code-block:: sql
+ EXISTS (
+ SELECT 1
+ FROM article_tag_version
+ WHERE article_id = 3
+ AND tag_id = tags_version.id
+ AND operation_type != 2
+ AND EXISTS (
+ SELECT 1
+ FROM article_tag_version as article_tag_version2
+ WHERE article_tag_version2.tag_id = article_tag_version.tag_id
+ AND article_tag_version2.tx_id <=5
+ GROUP BY article_tag_version2.tag_id
+ HAVING
+ MAX(article_tag_version2.tx_id) =
+ article_tag_version.tx_id
+ )
+ )
:param obj: SQLAlchemy declarative object
"""
+
tx_column = option(obj, 'transaction_column_name')
- reflector = ObjectExpressionReflector(obj)
- subquery = (
- getattr(self.remote_table.c, tx_column).in_(
- sa.select(
- [sa.func.max(getattr(self.remote_table.c, tx_column))],
- ).where(
- sa.and_(
- getattr(self.remote_table.c, tx_column) <=
- getattr(obj, tx_column),
- reflector(self.property.primaryjoin)
- )
- ).group_by(
- self.remote_table.c[self.remote_column.name]
- ).correlate(self.local_cls)
- )
- )
+ reflector = VersionExpressionReflector(obj, self.property)
- return (
+ association_table_alias = self.association_version_table.alias()
+ association_cols = [
+ association_table_alias.c[association_col.name]
+ for _, association_col
+ in self.remote_to_association_column_pairs
+ ]
+
+ association_exists = sa.exists(
sa.select(
- [self.remote_table.c[self.remote_column.name]]
+ [1]
+ ).where(
+ sa.and_(
+ association_table_alias.c[tx_column] <=
+ getattr(obj, tx_column),
+ *[association_col ==
+ self.association_version_table.c[association_col.name]
+ for association_col
+ in association_cols]
+ )
+ ).group_by(
+ *association_cols
+ ).having(
+ sa.func.max(association_table_alias.c[tx_column]) ==
+ self.association_version_table.c[tx_column]
+ ).correlate(self.association_version_table)
+ )
+ return sa.exists(
+ sa.select(
+ [1]
).where(
sa.and_(
- subquery,
reflector(self.property.primaryjoin),
- self.remote_table.c.operation_type != Operation.DELETE
+ association_exists,
+ self.association_version_table.c.operation_type !=
+ Operation.DELETE,
+ adapt_columns(self.property.secondaryjoin),
)
- )
+ ).correlate(self.local_cls, self.remote_cls)
)
def build_association_version_tables(self):
@@ -191,15 +316,20 @@ def build_association_version_tables(self):
column.table
)
metadata = column.table.metadata
- if metadata.schema:
+ if builder.parent_table.schema:
+ table_name = builder.parent_table.schema + '.' + builder.table_name
+ elif metadata.schema:
table_name = metadata.schema + '.' + builder.table_name
else:
table_name = builder.table_name
if table_name not in metadata.tables:
- table = builder()
-
+ self.association_version_table = table = builder()
self.manager.association_version_tables.add(table)
+ else:
+ # may have already been created if we visiting the 'other' side of
+ # a self-referential many-to-many relationship
+ self.association_version_table = metadata.tables[table_name]
def __call__(self):
"""
@@ -207,20 +337,27 @@ def __call__(self):
parent object's RelationshipProperty.
"""
self.local_cls = version_class(self.model)
+ self.versioned = False
try:
self.remote_cls = version_class(self.property.mapper.class_)
+ self.versioned = True
except (AttributeError, KeyError):
return
+ except ClassNotVersioned:
+ self.remote_cls = self.property.mapper.class_
- if self.property.secondary is not None:
+ if (self.property.secondary is not None and
+ not self.property.viewonly and
+ not self.manager.is_excluded_property(
+ self.model, self.property.key)):
self.build_association_version_tables()
+ # store remote cls to association table column pairs
+ self.remote_to_association_column_pairs = []
for column_pair in self.property.local_remote_pairs:
if column_pair[0] in self.property.table.c.values():
- self.remote_column = column_pair[1]
- break
+ self.remote_to_association_column_pairs.append(column_pair)
- self.remote_table = version_table(self.remote_column.table)
setattr(
self.local_cls,
self.property.key,
diff --git a/sqlalchemy_continuum/reverter.py b/sqlalchemy_continuum/reverter.py
index b2206e9c..68db7679 100644
--- a/sqlalchemy_continuum/reverter.py
+++ b/sqlalchemy_continuum/reverter.py
@@ -20,12 +20,13 @@ class ReverterException(Exception):
class Reverter(object):
- def __init__(self, obj, visited_objects=[], relations=[]):
- self.visited_objects = visited_objects
+ def __init__(self, obj, visited_objects=None, relations=[]):
+ self.visited_objects = visited_objects or []
self.obj = obj
self.version_parent = self.obj.version_parent
self.parent_class = parent_class(self.obj.__class__)
self.parent_mapper = sa.inspect(self.parent_class)
+ self.session = sa.orm.object_session(self.obj)
self.relations = list(relations)
for path in relations:
@@ -70,8 +71,15 @@ def revert_relationship(self, prop):
self.revert_association(prop)
else:
if prop.uselist:
- for value in getattr(self.obj, prop.key):
- self.revert_child(value, prop)
+ values = []
+ for child_obj in getattr(self.obj, prop.key):
+ value = self.revert_child(child_obj, prop)
+ if value:
+ values.append(value)
+
+ for value in getattr(self.version_parent, prop.key, []):
+ if value not in values:
+ self.session.delete(value)
else:
self.revert_child(getattr(self.obj, prop.key), prop)
@@ -95,12 +103,13 @@ def revert_relationships(self):
def __call__(self):
if self.obj in self.visited_objects:
- return
-
- session = sa.orm.object_session(self.obj)
+ return (
+ None if self.obj.operation_type == Operation.DELETE
+ else self.version_parent
+ )
if self.obj.operation_type == Operation.DELETE:
- session.delete(self.version_parent)
+ self.session.delete(self.version_parent)
return
self.visited_objects.append(self.obj)
@@ -108,7 +117,6 @@ def __call__(self):
# Check if parent object has been deleted
if self.version_parent is None:
self.version_parent = parent_class(self.obj.__class__)()
- session.add(self.version_parent)
# Before reifying relations we need to reify object properties. This
# is needed because reifying relations might need to flush the session
@@ -116,5 +124,6 @@ def __call__(self):
# into parent object (if parent object has not null constraints).
self.revert_properties()
self.revert_relationships()
+ self.session.add(self.version_parent)
return self.version_parent
diff --git a/sqlalchemy_continuum/schema.py b/sqlalchemy_continuum/schema.py
index e0c1a127..659df1b6 100644
--- a/sqlalchemy_continuum/schema.py
+++ b/sqlalchemy_continuum/schema.py
@@ -65,7 +65,9 @@ def update_end_tx_column(
executing the queries.
"""
if conn is None:
- from alembic import op as conn
+ from alembic import op
+
+ conn = op.get_bind()
query = get_end_tx_column_query(
table,
@@ -152,7 +154,9 @@ def update_property_mod_flags(
executing the queries.
"""
if conn is None:
- from alembic import op as conn
+ from alembic import op
+
+ conn = op.get_bind()
query = get_property_mod_flags_query(
table,
diff --git a/sqlalchemy_continuum/table_builder.py b/sqlalchemy_continuum/table_builder.py
index 358c25b9..12381247 100644
--- a/sqlalchemy_continuum/table_builder.py
+++ b/sqlalchemy_continuum/table_builder.py
@@ -1,20 +1,12 @@
import sqlalchemy as sa
+from sqlalchemy_utils import get_column_key
-class TableBuilder(object):
- """
- TableBuilder handles the building of version tables based on parent
- table's structure and versioning configuration options.
- """
- def __init__(
- self,
- versioning_manager,
- parent_table,
- model=None
- ):
- self.manager = versioning_manager
+class ColumnReflector(object):
+ def __init__(self, manager, parent_table, model=None):
self.parent_table = parent_table
self.model = model
+ self.manager = manager
def option(self, name):
try:
@@ -22,50 +14,6 @@ def option(self, name):
except TypeError:
return self.manager.options[name]
- @property
- def table_name(self):
- """
- Returns the version table name for current parent table.
- """
- return self.option('table_name') % self.parent_table.name
-
- @property
- def parent_columns(self):
- for column in self.parent_table.c:
- if (
- self.model and
- self.manager.is_excluded_column(self.model, column)
- ):
- continue
- if not self.model and column in self.manager.options['exclude']:
- continue
- yield column
-
- @property
- def reflected_columns(self):
- """
- Returns reflected parent table columns.
-
- All columns from parent table are reflected except those that:
- 1. Are auto assigned date or datetime columns. Use include option
- parameter if you wish to have these included.
- 2. Columns that are part of exclude option parameter.
- """
- columns = []
-
- transaction_column_name = self.option('transaction_column_name')
-
- for column in self.parent_columns:
- column_copy = self.reflect_column(column)
- columns.append(column_copy)
-
- # When using join table inheritance each table should have own
- # transaction column.
- if transaction_column_name not in [c.key for c in columns]:
- columns.append(sa.Column(transaction_column_name, sa.BigInteger))
-
- return columns
-
def reflect_column(self, column):
"""
Make a copy of parent table column and some alterations to it.
@@ -133,25 +81,90 @@ def end_transaction_column(self):
index=True
)
+ @property
+ def reflected_parent_columns(self):
+ for column in self.parent_table.c:
+ if (
+ self.model and
+ self.manager.is_excluded_column(self.model, column)
+ ):
+ continue
+ reflected_column = self.reflect_column(column)
+ yield reflected_column
+
+ def __iter__(self):
+ for column in self.reflected_parent_columns:
+ yield column
+
+ # Only yield internal version columns if parent model is not using
+ # single table inheritance
+ if not self.model or not sa.inspect(self.model).single:
+ yield self.transaction_column
+ if self.option('strategy') == 'validity':
+ yield self.end_transaction_column
+ yield self.operation_type_column
+
+
+class TableBuilder(object):
+ """
+ TableBuilder handles the building of version tables based on parent
+ table's structure and versioning configuration options.
+ """
+
+ def __init__(
+ self,
+ versioning_manager,
+ parent_table,
+ model=None
+ ):
+ self.manager = versioning_manager
+ self.parent_table = parent_table
+ self.model = model
+
+ def option(self, name):
+ try:
+ return self.manager.option(self.model, name)
+ except (TypeError, KeyError):
+ try:
+ return self.manager.options[name]
+ except KeyError:
+ return None
+
+ @property
+ def table_name(self):
+ """
+ Returns the version table name for current parent table.
+ """
+ table_name = self.option('table_name') % self.parent_table if '%s' in self.option(
+ 'table_name') else self.option(
+ 'table_name')
+ if '.' in table_name:
+ table_name = table_name.split('.')[-1]
+
+ return table_name
+
@property
def columns(self):
- data = self.reflected_columns
- data.append(self.transaction_column)
- if self.option('strategy') == 'validity':
- data.append(self.end_transaction_column)
- data.append(self.operation_type_column)
- return data
+ return list(
+ column for column in
+ ColumnReflector(self.manager, self.parent_table, self.model)
+ )
def __call__(self, extends=None):
"""
Builds version table.
"""
columns = self.columns if extends is None else []
-
self.manager.plugins.after_build_version_table_columns(self, columns)
+ schema = self.parent_table.schema
+ if self.option('schema_name'):
+ schema = self.option('schema_name')
+ elif self.option('versions_tables_schema_name'):
+ schema = self.option('versions_tables_schema_name')
return sa.schema.Table(
extends.name if extends is not None else self.table_name,
self.parent_table.metadata,
*columns,
+ schema=schema,
extend_existing=extends is not None
)
diff --git a/sqlalchemy_continuum/transaction.py b/sqlalchemy_continuum/transaction.py
index 8daf2c37..f56d84e0 100644
--- a/sqlalchemy_continuum/transaction.py
+++ b/sqlalchemy_continuum/transaction.py
@@ -1,6 +1,19 @@
from datetime import datetime
+
+try:
+ from collections import OrderedDict
+except ImportError:
+ from ordereddict import OrderedDict
+import six
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
+
+from .dialects.postgresql import (
+ CreateTemporaryTransactionTableSQL,
+ InsertTemporaryTransactionSQL,
+ TransactionTriggerSQL
+)
+from .exc import ImproperlyConfigured
from .factory import ModelFactory
@@ -10,7 +23,6 @@ def compile_big_integer(element, compiler, **kw):
class TransactionBase(object):
- id = sa.Column(sa.types.BigInteger, primary_key=True, autoincrement=True)
issued_at = sa.Column(sa.DateTime, default=datetime.utcnow)
@property
@@ -30,58 +42,139 @@ def changed_entities(self):
"""
manager = self.__versioning_manager__
tuples = set(manager.version_class_map.items())
- entities = []
+ entities = {}
- for class_, version_class in tuples:
+ session = sa.orm.object_session(self)
+ for class_, version_class in tuples:
if class_.__name__ not in self.entity_names:
continue
- try:
- value = getattr(
- self,
- manager.options['relation_naming_function'](
- class_.__name__
+ tx_column = manager.option(class_, 'transaction_column_name')
+
+ entities[version_class] = (
+ session
+ .query(version_class)
+ .filter(getattr(version_class, tx_column) == self.id)
+ ).all()
+ return entities
+
+
+procedure_sql = """
+CREATE OR REPLACE FUNCTION transaction_temp_table_generator()
+RETURNS TRIGGER AS $$
+BEGIN
+ {temporary_transaction_sql}
+ INSERT INTO temporary_transaction (id) VALUES (NEW.id);
+ RETURN NEW;
+END;
+$$
+LANGUAGE plpgsql
+"""
+
+
+def create_triggers(cls):
+ sa.event.listen(
+ cls.__table__,
+ 'after_create',
+ sa.schema.DDL(
+ procedure_sql.format(
+ temporary_transaction_sql=CreateTemporaryTransactionTableSQL(),
+ insert_temporary_transaction_sql=(
+ InsertTemporaryTransactionSQL(
+ transaction_id_values='NEW.id'
)
- )
- except AttributeError:
- continue
-
- if value:
- entities.append((
- version_class,
- value
- ))
- return dict(entities)
+ ),
+ )
+ )
+ )
+ sa.event.listen(
+ cls.__table__,
+ 'after_create',
+ sa.schema.DDL(str(TransactionTriggerSQL(cls)))
+ )
+ sa.event.listen(
+ cls.__table__,
+ 'after_drop',
+ sa.schema.DDL(
+ 'DROP FUNCTION IF EXISTS transaction_temp_table_generator()'
+ )
+ )
class TransactionFactory(ModelFactory):
model_name = 'Transaction'
- def __init__(self, user=True, remote_addr=True):
- self.user = user
+ def __init__(self, remote_addr=True):
self.remote_addr = remote_addr
def create_class(self, manager):
"""
Create Transaction class.
"""
+
class Transaction(
manager.declarative_base,
TransactionBase
):
- __tablename__ = 'transaction'
+ __tablename__ = manager.options['transaction_table_name']
__versioning_manager__ = manager
+ __table_args__ = {u'schema': manager.options['transaction_table_schema_name']}
+
+ id = sa.Column(
+ sa.types.BigInteger,
+ primary_key=True,
+ autoincrement=True,
+ server_default=sa.text(
+ "nextval('{}.equipment_seq'::regclass)".format(manager.options['transaction_table_schema_name']))
+ )
if self.remote_addr:
remote_addr = sa.Column(sa.String(50))
- if self.user:
+ if manager.user_cls:
+ user_cls = manager.user_cls
+ registry = manager.declarative_base._decl_class_registry
+
+ if isinstance(user_cls, six.string_types):
+ try:
+ user_cls = registry[user_cls]
+ except KeyError:
+ raise ImproperlyConfigured(
+ 'Could not build relationship between Transaction'
+ ' and %s. %s was not found in declarative class '
+ 'registry. Either configure VersioningManager to '
+ 'use different user class or disable this '
+ 'relationship ' % (user_cls, user_cls)
+ )
+
user_id = sa.Column(
- sa.Integer,
- sa.ForeignKey('user.id'),
+ sa.inspect(user_cls).primary_key[0].type,
+ sa.ForeignKey(sa.inspect(user_cls).primary_key[0]),
index=True
)
- user = sa.orm.relationship('User')
+ user = sa.orm.relationship(user_cls)
+
+ def __repr__(self):
+ fields = ['id', 'issued_at', 'user']
+ field_values = OrderedDict(
+ (field, getattr(self, field))
+ for field in fields
+ if hasattr(self, field)
+ )
+ return '' % ', '.join(
+ (
+ '%s=%r' % (field, value)
+ if not isinstance(value, six.integer_types)
+ # We want the following line to ensure that longs get
+ # shown without the ugly L suffix on python 2.x
+ # versions
+ else '%s=%d' % (field, value)
+ for field, value in field_values.items()
+ )
+ )
+
+ if manager.options['native_versioning']:
+ create_triggers(Transaction)
return Transaction
diff --git a/sqlalchemy_continuum/unit_of_work.py b/sqlalchemy_continuum/unit_of_work.py
index 0dcfee76..5f91b13d 100644
--- a/sqlalchemy_continuum/unit_of_work.py
+++ b/sqlalchemy_continuum/unit_of_work.py
@@ -1,7 +1,7 @@
from copy import copy
import sqlalchemy as sa
-from sqlalchemy_utils import identity
+from sqlalchemy_utils import get_primary_keys, identity
from .operation import Operations
from .utils import (
end_tx_column_name,
@@ -17,7 +17,7 @@ def __init__(self, manager):
self.manager = manager
self.reset()
- def reset(self):
+ def reset(self, session=None):
"""
Reset the internal state of this UnitOfWork object. Normally this is
called after transaction has been committed or rolled back.
@@ -43,6 +43,19 @@ def is_modified(self, session):
)
def process_before_flush(self, session):
+ """
+ Before flush processor for given session.
+
+ This method creates a version session which is later on used for the
+ creation of version objects. It also creates Transaction object for the
+ current transaction and invokes before_flush template method on all
+ plugins.
+
+ If the given session had no relevant modifications regarding versioned
+ objects this method does nothing.
+
+ :param session: SQLAlchemy session object
+ """
if session == self.version_session:
return
@@ -74,19 +87,40 @@ def process_after_flush(self, session):
if not self.current_transaction:
return
+ if not self.version_session:
+ self.version_session = sa.orm.session.Session(
+ bind=session.connection()
+ )
+
self.make_versions(session)
+ def transaction_args(self, session):
+ args = {}
+ for plugin in self.manager.plugins:
+ args.update(plugin.transaction_args(self, session))
+ return args
+
def create_transaction(self, session):
"""
Create transaction object for given SQLAlchemy session.
:param session: SQLAlchemy session object
"""
- self.current_transaction = self.manager.transaction_cls()
- self.manager.plugins.before_create_tx_object(self, session)
- session.add(self.current_transaction)
- self.manager.plugins.after_create_tx_object(self, session)
+ args = self.transaction_args(session)
+
+ Transaction = self.manager.transaction_cls
+ self.current_transaction = Transaction()
+ for key, value in args.items():
+ setattr(self.current_transaction, key, value)
+ if not self.version_session:
+ self.version_session = sa.orm.session.Session(
+ bind=session.connection()
+ )
+ self.version_session.add(self.current_transaction)
+ self.version_session.flush()
+ self.version_session.expunge(self.current_transaction)
+ session.add(self.current_transaction)
return self.current_transaction
def get_or_create_version_object(self, target):
@@ -151,7 +185,10 @@ def create_version_objects(self, session):
:param session: SQLAlchemy session object
"""
- if not self.manager.options['versioning']:
+ if (
+ not self.manager.options['versioning'] or
+ self.manager.options['native_versioning']
+ ):
return
for key, operation in copy(self.operations).items():
@@ -166,7 +203,7 @@ def create_version_objects(self, session):
self.version_session.flush()
- def version_validity_subquery(self, parent, version_obj):
+ def version_validity_subquery(self, parent, version_obj, alias=None):
"""
Return the subquery needed by :func:`update_version_validity`.
@@ -181,11 +218,13 @@ def version_validity_subquery(self, parent, version_obj):
session = sa.orm.object_session(version_obj)
subquery = fetcher._transaction_id_subquery(
- version_obj, next_or_prev='prev'
+ version_obj,
+ next_or_prev='prev',
+ alias=alias
)
if session.connection().engine.dialect.name == 'mysql':
return sa.select(
- ['max_1'],
+ [sa.text('max_1')],
from_obj=[
sa.sql.expression.alias(subquery, name='subquery')
]
@@ -204,29 +243,40 @@ def update_version_validity(self, parent, version_obj):
.. seealso:: :func:`version_validity_subquery`
"""
- fetcher = self.manager.fetcher(parent)
session = sa.orm.object_session(version_obj)
- subquery = self.version_validity_subquery(parent, version_obj)
- query = (
- session.query(version_obj.__class__)
- .filter(
- sa.and_(
- getattr(
- version_obj.__class__,
- tx_column_name(version_obj)
- ) == subquery,
- *fetcher.parent_identity_correlation(version_obj)
+ for class_ in version_obj.__class__.__mro__:
+ if class_ in self.manager.parent_class_map:
+
+ subquery = self.version_validity_subquery(
+ parent,
+ version_obj,
+ alias=sa.orm.aliased(class_.__table__)
+ )
+ query = (
+ session.query(class_.__table__)
+ .filter(
+ sa.and_(
+ getattr(
+ class_,
+ tx_column_name(version_obj)
+ ) == subquery,
+ *[
+ getattr(version_obj, pk) ==
+ getattr(class_.__table__.c, pk)
+ for pk in get_primary_keys(class_)
+ if pk != tx_column_name(class_)
+ ]
+ )
+ )
+ )
+ query.update(
+ {
+ end_tx_column_name(version_obj):
+ self.current_transaction.id
+ },
+ synchronize_session=False
)
- )
- )
- query.update(
- {
- end_tx_column_name(version_obj):
- self.current_transaction.id
- },
- synchronize_session=False
- )
def create_association_versions(self, session):
"""
@@ -236,7 +286,12 @@ def create_association_versions(self, session):
"""
statements = copy(self.pending_statements)
for stmt in statements:
- stmt = stmt.values(transaction_id=self.current_transaction.id)
+ stmt = stmt.values(
+ **{
+ self.manager.options['transaction_column_name']:
+ self.current_transaction.id
+ }
+ )
session.execute(stmt)
self.pending_statements = []
@@ -274,5 +329,8 @@ def assign_attributes(self, parent_obj, version_obj):
Version object to assign the attribute values to
"""
for prop in versioned_column_properties(parent_obj):
- value = getattr(parent_obj, prop.key)
+ try:
+ value = getattr(parent_obj, prop.key)
+ except sa.orm.exc.ObjectDeletedError:
+ value = None
setattr(version_obj, prop.key, value)
diff --git a/sqlalchemy_continuum/utils.py b/sqlalchemy_continuum/utils.py
index 59cfb28a..c622760e 100644
--- a/sqlalchemy_continuum/utils.py
+++ b/sqlalchemy_continuum/utils.py
@@ -1,12 +1,17 @@
+from itertools import chain
from inspect import isclass
from collections import defaultdict
+
import sqlalchemy as sa
-from sqlalchemy.orm import object_session
from sqlalchemy.orm.attributes import get_history
-from sqlalchemy.orm.exc import UnmappedInstanceError
-from sqlalchemy.orm.properties import ColumnProperty
from sqlalchemy.orm.util import AliasedClass
-from sqlalchemy_utils.functions import naturally_equivalent, identity
+from sqlalchemy_utils.functions import (
+ get_primary_keys,
+ identity,
+ naturally_equivalent,
+)
+
+from .exc import ClassNotVersioned
def get_versioning_manager(obj_or_class):
@@ -19,10 +24,20 @@ def get_versioning_manager(obj_or_class):
if isinstance(obj_or_class, AliasedClass):
obj_or_class = sa.inspect(obj_or_class).mapper.class_
cls = obj_or_class if isclass(obj_or_class) else obj_or_class.__class__
- return cls.__versioning_manager__
+ try:
+ return cls.__versioning_manager__
+ except AttributeError:
+ raise ClassNotVersioned(cls.__name__)
def option(obj_or_class, option_name):
+ """
+ Return the option value of given option for given versioned object or
+ class.
+
+ :param obj_or_class: SQLAlchemy declarative model object or class
+ :param option_name: The name of an option to return
+ """
if isinstance(obj_or_class, AliasedClass):
obj_or_class = sa.inspect(obj_or_class).mapper.class_
cls = obj_or_class if isclass(obj_or_class) else obj_or_class.__class__
@@ -48,23 +63,6 @@ def end_tx_attr(obj):
)
-def get_bind(obj):
- if hasattr(obj, 'bind'):
- conn = obj.bind
- else:
- try:
- conn = object_session(obj).bind
- except UnmappedInstanceError:
- conn = obj
-
- if not isinstance(conn, sa.engine.base.Connection):
- raise TypeError(
- 'This method accepts only Session, Connection and declarative '
- 'model objects.'
- )
- return conn
-
-
def parent_class(version_cls):
"""
Return the parent class for given version model class.
@@ -122,7 +120,11 @@ def version_class(model):
.. seealso:: :func:`parent_class`
"""
- return get_versioning_manager(model).version_class_map[model]
+ manager = get_versioning_manager(model)
+ try:
+ return manager.version_class_map[model]
+ except KeyError:
+ return model
def version_table(table):
@@ -131,7 +133,11 @@ def version_table(table):
:param table: SQLAlchemy Table object
"""
- if table.metadata.schema:
+ if table.schema:
+ return table.metadata.tables[
+ table.schema + '.' + table.name + '_version'
+ ]
+ elif table.metadata.schema:
return table.metadata.tables[
table.metadata.schema + '.' + table.name + '_version'
]
@@ -180,7 +186,7 @@ def is_versioned(obj_or_class):
obj_or_class, 'versioning'
)
)
- except (AttributeError, KeyError):
+ except ClassNotVersioned:
return False
@@ -195,14 +201,13 @@ def versioned_column_properties(obj_or_class):
cls = obj_or_class if isclass(obj_or_class) else obj_or_class.__class__
- for prop in sa.inspect(cls).attrs.values():
- if not isinstance(prop, ColumnProperty):
- continue
- if not manager.is_excluded_column(obj_or_class, prop.columns[0]):
- yield prop
+ mapper = sa.inspect(cls)
+ for key in mapper.columns.keys():
+ if not manager.is_excluded_property(obj_or_class, key):
+ yield getattr(mapper.attrs, key)
-def versioned_relationships(obj):
+def versioned_relationships(obj, versioned_column_keys):
"""
Return all versioned relationships for given versioned SQLAlchemy
declarative model object.
@@ -210,11 +215,11 @@ def versioned_relationships(obj):
:param obj: SQLAlchemy declarative model object
"""
for prop in sa.inspect(obj.__class__).relationships:
- if is_versioned(prop.mapper.class_):
+ if any(c.key in versioned_column_keys for c in prop.local_columns):
yield prop
-def vacuum(session, model):
+def vacuum(session, model, yield_per=1000):
"""
When making structural changes to version tables (for example dropping
columns) there are sometimes situations where some old version records
@@ -235,6 +240,7 @@ def vacuum(session, model):
:param session: SQLAlchemy session object
:param model: SQLAlchemy declarative model class
+ :param yield_per: how many rows to process at a time
"""
version_cls = version_class(model)
versions = defaultdict(list)
@@ -242,15 +248,18 @@ def vacuum(session, model):
query = (
session.query(version_cls)
.order_by(option(version_cls, 'transaction_column_name'))
- )
+ ).yield_per(yield_per)
+
+ primary_key_col = sa.inspection.inspect(model).primary_key[0].name
for version in query:
- if versions[version.id]:
- prev_version = versions[version.id][-1]
+ version_id = getattr(version, primary_key_col)
+ if versions[version_id]:
+ prev_version = versions[version_id][-1]
if naturally_equivalent(prev_version, version):
session.delete(version)
else:
- versions[version.id].append(version)
+ versions[version_id].append(version)
def is_internal_column(model, column_name):
@@ -277,7 +286,10 @@ def is_modified_or_deleted(obj):
:param obj: SQLAlchemy declarative model object
"""
session = sa.orm.object_session(obj)
- return is_modified(obj) or obj in session.deleted
+ return is_versioned(obj) and (
+ is_modified(obj) or
+ obj in chain(session.deleted, session.new)
+ )
def is_modified(obj):
@@ -306,7 +318,8 @@ def is_modified(obj):
prop.key for prop in versioned_column_properties(obj)
]
versioned_relationship_keys = [
- prop.key for prop in versioned_relationships(obj)
+ prop.key
+ for prop in versioned_relationships(obj, versioned_column_keys)
]
for key, attr in sa.inspect(obj).attrs.items():
if key in column_names:
@@ -335,6 +348,43 @@ def is_session_modified(session):
)
+def count_versions(obj):
+ """
+ Return the number of versions given object has. This function works even
+ when obj has `create_models` and `create_tables` versioned settings
+ disabled.
+
+ ::
+
+ article = Article(name=u'Some article')
+
+ count_versions(article) # 0
+
+ session.add(article)
+ session.commit()
+
+ count_versions(article) # 1
+
+
+ :param obj: SQLAlchemy declarative model object
+ """
+ session = sa.orm.object_session(obj)
+ if session is None:
+ # If object is transient, we assume it has no version history.
+ return 0
+ manager = get_versioning_manager(obj)
+ table_name = manager.option(obj, 'table_name') % obj.__table__.name
+ criteria = [
+ '%s = %r' % (pk, getattr(obj, pk))
+ for pk in get_primary_keys(obj)
+ ]
+ query = 'SELECT COUNT(1) FROM %s WHERE %s' % (
+ table_name,
+ ' AND '.join(criteria)
+ )
+ return session.execute(query).scalar()
+
+
def changeset(obj):
"""
Return a humanized changeset for given SQLAlchemy declarative object. With
@@ -371,3 +421,15 @@ def changeset(obj):
if new_value:
data[prop.key] = [new_value, old_value]
return data
+
+
+
+class VersioningClauseAdapter(sa.sql.visitors.ReplacingCloningVisitor):
+ def replace(self, col):
+ if isinstance(col, sa.Column):
+ table = version_table(col.table)
+ return table.c.get(col.key)
+
+
+def adapt_columns(expr):
+ return VersioningClauseAdapter().traverse(expr)
diff --git a/sqlalchemy_continuum/version.py b/sqlalchemy_continuum/version.py
index 5c3c1ed2..d71e745d 100644
--- a/sqlalchemy_continuum/version.py
+++ b/sqlalchemy_continuum/version.py
@@ -1,4 +1,5 @@
import sqlalchemy as sa
+
from .reverter import Reverter
from .utils import get_versioning_manager, is_internal_column, parent_class
@@ -49,9 +50,6 @@ def changeset(self):
and second list value as the new value.
"""
previous_version = self.previous
- if not previous_version and self.operation_type != 0:
- return {}
-
data = {}
for key in sa.inspect(self.__class__).columns.keys():
diff --git a/tests/__init__.py b/tests/__init__.py
index b1b70097..310e9bb6 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -8,9 +8,11 @@
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy_continuum import (
+ ClassNotVersioned,
version_class,
make_versioned,
versioning_manager,
+ remove_versioning
)
from sqlalchemy_continuum.transaction import TransactionFactory
from sqlalchemy_continuum.plugins import (
@@ -21,9 +23,6 @@
warnings.simplefilter('error', sa.exc.SAWarning)
-make_versioned(options={'strategy': 'subquery'})
-
-
class QueryPool(object):
queries = []
@@ -40,58 +39,80 @@ def log_sql(
QueryPool.queries.append(statement)
+def get_dns_from_driver(driver):
+ if driver == 'postgres':
+ return 'postgres://postgres@localhost/sqlalchemy_continuum_test'
+ elif driver == 'mysql':
+ return 'mysql+pymysql://travis@localhost/sqlalchemy_continuum_test'
+ elif driver == 'sqlite':
+ return 'sqlite:///:memory:'
+ else:
+ raise Exception('Unknown driver given: %r' % driver)
+
+
+def get_driver_name(driver):
+ return driver[0:-len('-native')] if driver.endswith('-native') else driver
+
+
+def uses_native_versioning():
+ return os.environ.get('DB', 'sqlite').endswith('-native')
+
+
class TestCase(object):
versioning_strategy = 'subquery'
transaction_column_name = 'transaction_id'
end_transaction_column_name = 'end_transaction_id'
composite_pk = False
plugins = [TransactionChangesPlugin(), TransactionMetaPlugin()]
- transaction_cls = TransactionFactory(user=False)
+ transaction_cls = TransactionFactory()
+ user_cls = None
+ should_create_models = True
@property
def options(self):
return {
+ 'create_models': self.should_create_models,
+ 'native_versioning': uses_native_versioning(),
'base_classes': (self.Model, ),
'strategy': self.versioning_strategy,
'transaction_column_name': self.transaction_column_name,
'end_transaction_column_name': self.end_transaction_column_name,
}
- def get_dns_from_driver(self, driver):
- if driver == 'postgres':
- return 'postgres://postgres@localhost/sqlalchemy_continuum_test'
- elif driver == 'mysql':
- return 'mysql+pymysql://travis@localhost/sqlalchemy_continuum_test'
- elif driver == 'sqlite':
- return 'sqlite:///:memory:'
- else:
- raise Exception('Unknown driver given: %r' % driver)
-
def setup_method(self, method):
+ self.Model = declarative_base()
+ make_versioned(options=self.options)
+
driver = os.environ.get('DB', 'sqlite')
+ self.driver = get_driver_name(driver)
versioning_manager.plugins = self.plugins
versioning_manager.transaction_cls = self.transaction_cls
+ versioning_manager.user_cls = self.user_cls
- self.engine = create_engine(self.get_dns_from_driver(driver))
+ self.engine = create_engine(get_dns_from_driver(self.driver))
# self.engine.echo = True
- self.connection = self.engine.connect()
- self.Model = declarative_base()
-
self.create_models()
sa.orm.configure_mappers()
+ self.connection = self.engine.connect()
+
if hasattr(self, 'Article'):
- self.ArticleVersion = version_class(self.Article)
+ try:
+ self.ArticleVersion = version_class(self.Article)
+ except ClassNotVersioned:
+ pass
if hasattr(self, 'Tag'):
try:
self.TagVersion = version_class(self.Tag)
- except (AttributeError, KeyError):
+ except ClassNotVersioned:
pass
self.create_tables()
Session = sessionmaker(bind=self.connection)
- self.session = Session()
+ self.session = Session(autoflush=False)
+ if driver == 'postgres-native':
+ self.session.execute('CREATE EXTENSION IF NOT EXISTS hstore')
def create_tables(self):
self.Model.metadata.create_all(self.connection)
@@ -100,6 +121,11 @@ def drop_tables(self):
self.Model.metadata.drop_all(self.connection)
def teardown_method(self, method):
+ self.session.rollback()
+ uow_leaks = versioning_manager.units_of_work
+ session_map_leaks = versioning_manager.session_connection_map
+
+ remove_versioning()
QueryPool.queries = []
versioning_manager.reset()
@@ -109,6 +135,9 @@ def teardown_method(self, method):
self.engine.dispose()
self.connection.close()
+ assert not uow_leaks
+ assert not session_map_leaks
+
def create_models(self):
class Article(self.Model):
__tablename__ = 'article'
@@ -133,9 +162,18 @@ class Tag(self.Model):
setting_variants = {
- 'versioning_strategy': ['subquery', 'validity'],
- 'transaction_column_name': ['transaction_id', 'tx_id'],
- 'end_transaction_column_name': ['end_transaction_id', 'end_tx_id']
+ 'versioning_strategy': [
+ 'subquery',
+ 'validity',
+ ],
+ 'transaction_column_name': [
+ 'transaction_id',
+ 'tx_id'
+ ],
+ 'end_transaction_column_name': [
+ 'end_transaction_id',
+ 'end_tx_id'
+ ]
}
diff --git a/tests/builders/test_table_builder.py b/tests/builders/test_table_builder.py
index 16323d02..a2255c83 100644
--- a/tests/builders/test_table_builder.py
+++ b/tests/builders/test_table_builder.py
@@ -3,6 +3,7 @@
import sqlalchemy as sa
from sqlalchemy_continuum import version_class
from tests import TestCase
+from pytest import mark
class TestTableBuilder(TestCase):
@@ -69,3 +70,31 @@ class Article(self.Model):
def test_takes_out_onupdate_triggers(self):
table = version_class(self.Article).__table__
assert table.c.last_update.onupdate is None
+
+@mark.skipif("os.environ.get('DB') == 'sqlite'")
+class TestTableBuilderInOtherSchema(TestCase):
+ def create_models(self):
+ class Article(self.Model):
+ __tablename__ = 'article'
+ __versioned__ = copy(self.options)
+ __table_args__ = {'schema': 'other'}
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ last_update = sa.Column(
+ sa.DateTime,
+ default=datetime.utcnow,
+ onupdate=datetime.utcnow,
+ nullable=False
+ )
+ self.Article = Article
+
+ def create_tables(self):
+ self.connection.execute('DROP SCHEMA IF EXISTS other')
+ self.connection.execute('CREATE SCHEMA other')
+ TestCase.create_tables(self)
+
+ def test_created_tables_retain_schema(self):
+ table = version_class(self.Article).__table__
+ assert table.schema is not None
+ assert table.schema == self.Article.__table__.schema
+
diff --git a/tests/dialects/__init__.py b/tests/dialects/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/dialects/test_triggers.py b/tests/dialects/test_triggers.py
new file mode 100644
index 00000000..a9f96d6d
--- /dev/null
+++ b/tests/dialects/test_triggers.py
@@ -0,0 +1,61 @@
+import os
+
+import pytest
+import sqlalchemy as sa
+
+from sqlalchemy_continuum.dialects.postgresql import (
+ drop_trigger,
+ sync_trigger
+)
+from tests import (
+ get_dns_from_driver,
+ get_driver_name,
+ QueryPool,
+ uses_native_versioning
+)
+
+
+@pytest.mark.skipif('not uses_native_versioning()')
+class TestTriggerSyncing(object):
+ def setup_method(self, method):
+ driver = os.environ.get('DB', 'sqlite')
+ self.driver = get_driver_name(driver)
+ self.engine = sa.create_engine(get_dns_from_driver(self.driver))
+ self.connection = self.engine.connect()
+ if driver == 'postgres-native':
+ self.connection.execute('CREATE EXTENSION IF NOT EXISTS hstore')
+
+ self.connection.execute(
+ 'CREATE TABLE article '
+ '(id INT PRIMARY KEY, name VARCHAR(200), content TEXT)'
+ )
+ self.connection.execute(
+ 'CREATE TABLE article_version '
+ '(id INT, transaction_id INT, name VARCHAR(200), '
+ 'name_mod BOOLEAN, PRIMARY KEY (id, transaction_id))'
+ )
+
+ def teardown_method(self, method):
+ self.connection.execute('DROP TABLE IF EXISTS article')
+ self.connection.execute('DROP TABLE IF EXISTS article_version')
+ self.engine.dispose()
+ self.connection.close()
+
+ def test_sync_triggers(self):
+ sync_trigger(self.connection, 'article_version')
+ assert (
+ 'DROP TRIGGER IF EXISTS article_trigger ON "article"'
+ in QueryPool.queries[-4]
+ )
+ assert 'DROP FUNCTION ' in QueryPool.queries[-3]
+ assert 'CREATE OR REPLACE FUNCTION ' in QueryPool.queries[-2]
+ assert 'CREATE TRIGGER ' in QueryPool.queries[-1]
+ sync_trigger(self.connection, 'article_version')
+
+ def test_drop_triggers(self):
+ drop_trigger(self.connection, 'article')
+ assert (
+ 'DROP TRIGGER IF EXISTS article_trigger ON "article"'
+ in QueryPool.queries[-2]
+ )
+ assert 'DROP FUNCTION ' in QueryPool.queries[-1]
diff --git a/tests/inheritance/test_concrete_inheritance.py b/tests/inheritance/test_concrete_inheritance.py
new file mode 100644
index 00000000..e5c82ed2
--- /dev/null
+++ b/tests/inheritance/test_concrete_inheritance.py
@@ -0,0 +1,91 @@
+from pytest import mark
+import sqlalchemy as sa
+from sqlalchemy_continuum import versioning_manager, version_class
+from tests import TestCase
+
+
+class TestConreteTableInheritance(TestCase):
+ def create_models(self):
+ class TextItem(self.Model):
+ __tablename__ = 'text_item'
+ __versioned__ = {
+ 'base_classes': (self.Model, )
+ }
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+
+ discriminator = sa.Column(
+ sa.Unicode(100)
+ )
+
+ __mapper_args__ = {
+ 'polymorphic_on': discriminator
+ }
+
+ class Article(TextItem):
+ __tablename__ = 'article'
+ __mapper_args__ = {
+ 'polymorphic_identity': u'article',
+ 'concrete': True
+ }
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ class BlogPost(TextItem):
+ __tablename__ = 'blog_post'
+ __mapper_args__ = {
+ 'polymorphic_identity': u'blog_post',
+ 'concrete': True
+ }
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ title = sa.Column(sa.Unicode(255))
+
+ self.TextItem = TextItem
+ self.Article = Article
+ self.BlogPost = BlogPost
+
+ def setup_method(self, method):
+ TestCase.setup_method(self, method)
+ self.TextItemVersion = version_class(self.TextItem)
+ self.ArticleVersion = version_class(self.Article)
+ self.BlogPostVersion = version_class(self.BlogPost)
+
+ def test_inheritance(self):
+ assert issubclass(self.ArticleVersion, self.TextItemVersion)
+ assert issubclass(self.BlogPostVersion, self.TextItemVersion)
+
+ def test_version_class_map(self):
+ manager = self.TextItem.__versioning_manager__
+ assert len(manager.version_class_map.keys()) == 3
+
+ def test_each_class_has_distinct_version_class(self):
+ assert self.TextItemVersion.__table__.name == 'text_item_version'
+ assert self.ArticleVersion.__table__.name == 'article_version'
+ assert self.BlogPostVersion.__table__.name == 'blog_post_version'
+
+ @mark.skipif('True')
+ def test_each_object_has_distinct_version_class(self):
+ article = self.Article()
+ blogpost = self.BlogPost()
+ textitem = self.TextItem()
+
+ self.session.add(article)
+ self.session.add(blogpost)
+ self.session.add(textitem)
+ self.session.commit()
+
+ assert type(textitem.versions[0]) == self.TextItemVersion
+ assert type(article.versions[0]) == self.ArticleVersion
+ assert type(blogpost.versions[0]) == self.BlogPostVersion
+
+ def test_transaction_changed_entities(self):
+ article = self.Article()
+ article.name = u'Text 1'
+ self.session.add(article)
+ self.session.commit()
+ Transaction = versioning_manager.transaction_cls
+ transaction = (
+ self.session.query(Transaction)
+ .order_by(sa.sql.expression.desc(Transaction.issued_at))
+ ).first()
+ assert transaction.entity_names == [u'Article']
+ assert transaction.changed_entities
diff --git a/tests/inheritance/test_join_table_inheritance.py b/tests/inheritance/test_join_table_inheritance.py
index 15df6da8..cb598ab6 100644
--- a/tests/inheritance/test_join_table_inheritance.py
+++ b/tests/inheritance/test_join_table_inheritance.py
@@ -1,10 +1,10 @@
-from pytest import mark
+import pytest
import sqlalchemy as sa
from sqlalchemy_continuum import version_class
-from tests import TestCase
+from tests import TestCase, uses_native_versioning, create_test_cases
-class TestJoinTableInheritance(TestCase):
+class JoinTableInheritanceTestCase(TestCase):
def create_models(self):
class TextItem(self.Model):
__tablename__ = 'text_item'
@@ -13,12 +13,15 @@ class TextItem(self.Model):
}
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
discriminator = sa.Column(
sa.Unicode(100)
)
__mapper_args__ = {
'polymorphic_on': discriminator,
+ 'with_polymorphic': '*'
}
class Article(TextItem):
@@ -27,7 +30,7 @@ class Article(TextItem):
id = sa.Column(
sa.Integer,
sa.ForeignKey(TextItem.id),
- autoincrement=True, primary_key=True
+ primary_key=True
)
class BlogPost(TextItem):
@@ -36,7 +39,7 @@ class BlogPost(TextItem):
id = sa.Column(
sa.Integer,
sa.ForeignKey(TextItem.id),
- autoincrement=True, primary_key=True
+ primary_key=True
)
self.TextItem = TextItem
@@ -49,14 +52,14 @@ def setup_method(self, method):
self.ArticleVersion = version_class(self.Article)
self.BlogPostVersion = version_class(self.BlogPost)
- def test_each_class_has_distinct_version_class(self):
+ def test_each_class_has_distinct_version_table(self):
assert self.TextItemVersion.__table__.name == 'text_item_version'
assert self.ArticleVersion.__table__.name == 'article_version'
assert self.BlogPostVersion.__table__.name == 'blog_post_version'
+
assert issubclass(self.ArticleVersion, self.TextItemVersion)
assert issubclass(self.BlogPostVersion, self.TextItemVersion)
- @mark.skipif('True')
def test_each_object_has_distinct_version_class(self):
article = self.Article()
blogpost = self.BlogPost()
@@ -67,14 +70,24 @@ def test_each_object_has_distinct_version_class(self):
self.session.add(textitem)
self.session.commit()
- assert type(textitem.versions[0]) == self.TextItemVersion
+ # assert type(textitem.versions[0]) == self.TextItemVersion
assert type(article.versions[0]) == self.ArticleVersion
assert type(blogpost.versions[0]) == self.BlogPostVersion
def test_all_tables_contain_transaction_id_column(self):
- assert 'transaction_id' in self.TextItemVersion.__table__.c
- assert 'transaction_id' in self.ArticleVersion.__table__.c
- assert 'transaction_id' in self.BlogPostVersion.__table__.c
+ tx_column = self.options['transaction_column_name']
+
+ assert tx_column in self.TextItemVersion.__table__.c
+ assert tx_column in self.ArticleVersion.__table__.c
+ assert tx_column in self.BlogPostVersion.__table__.c
+
+ def test_with_polymorphic(self):
+ article = self.Article()
+ self.session.add(article)
+ self.session.commit()
+
+ version_obj = self.session.query(self.TextItemVersion).first()
+ assert isinstance(version_obj, self.ArticleVersion)
def test_consecutive_insert_and_delete(self):
article = self.Article()
@@ -84,22 +97,109 @@ def test_consecutive_insert_and_delete(self):
self.session.commit()
def test_assign_transaction_id_to_both_parent_and_child_tables(self):
+ tx_column = self.options['transaction_column_name']
article = self.Article()
self.session.add(article)
self.session.commit()
assert self.session.execute(
- 'SELECT transaction_id FROM article_version'
- ).fetchone()[0] == 1
+ 'SELECT %s FROM article_version' % tx_column
+ ).fetchone()[0]
assert self.session.execute(
- 'SELECT transaction_id FROM text_item_version'
- ).fetchone()[0] == 1
+ 'SELECT %s FROM text_item_version' % tx_column
+ ).fetchone()[0]
def test_primary_keys(self):
+ tx_column = self.options['transaction_column_name']
table = self.TextItemVersion.__table__
assert len(table.primary_key.columns)
assert 'id' in table.primary_key.columns
- assert 'transaction_id' in table.primary_key.columns
+ assert tx_column in table.primary_key.columns
table = self.ArticleVersion.__table__
assert len(table.primary_key.columns)
assert 'id' in table.primary_key.columns
- assert 'transaction_id' in table.primary_key.columns
+ assert tx_column in table.primary_key.columns
+
+ @pytest.mark.skipif('uses_native_versioning()')
+ def test_updates_end_transaction_id_to_all_tables(self):
+ if self.options['strategy'] == 'subquery':
+ pytest.skip()
+
+ end_tx_column = self.options['end_transaction_column_name']
+ tx_column = self.options['transaction_column_name']
+ article = self.Article()
+ self.session.add(article)
+ self.session.commit()
+ article.name = u'Updated article'
+ self.session.commit()
+ assert article.versions.count() == 2
+
+ assert self.session.execute(
+ 'SELECT %s FROM text_item_version '
+ 'ORDER BY %s LIMIT 1' % (end_tx_column, tx_column)
+ ).scalar()
+ assert self.session.execute(
+ 'SELECT %s FROM article_version '
+ 'ORDER BY %s LIMIT 1' % (end_tx_column, tx_column)
+ ).scalar()
+
+
+create_test_cases(JoinTableInheritanceTestCase)
+
+
+class TestDeepJoinedTableInheritance(TestCase):
+ def create_models(self):
+ class Node(self.Model):
+ __versioned__ = {}
+ __tablename__ = 'node'
+ __mapper_args__ = dict(
+ polymorphic_on='type',
+ polymorphic_identity='node',
+ with_polymorphic='*',
+ )
+
+ id = sa.Column(sa.Integer, primary_key=True)
+ type = sa.Column(sa.String(30), nullable=False)
+
+ class Content(Node):
+ __versioned__ = {}
+ __tablename__ = 'content'
+ __mapper_args__ = {
+ 'polymorphic_identity': 'content'
+ }
+ id = sa.Column(
+ sa.Integer,
+ sa.ForeignKey('node.id'),
+ primary_key=True
+ )
+ description = sa.Column(sa.UnicodeText())
+
+ class Document(Content):
+ __versioned__ = {}
+ __tablename__ = 'document'
+ __mapper_args__ = {
+ 'polymorphic_identity': 'document'
+ }
+ id = sa.Column(
+ sa.Integer,
+ sa.ForeignKey('content.id'),
+ primary_key=True
+ )
+ body = sa.Column(sa.UnicodeText)
+
+ self.Node = Node
+ self.Content = Content
+ self.Document = Document
+
+ def test_insert(self):
+ document = self.Document()
+ self.session.add(document)
+ self.session.commit()
+ assert self.session.execute(
+ 'SELECT COUNT(1) FROM document_version'
+ ).scalar() == 1
+ assert self.session.execute(
+ 'SELECT COUNT(1) FROM content_version'
+ ).scalar() == 1
+ assert self.session.execute(
+ 'SELECT COUNT(1) FROM node_version'
+ ).scalar() == 1
diff --git a/tests/inheritance/test_multi_level_inheritance.py b/tests/inheritance/test_multi_level_inheritance.py
new file mode 100644
index 00000000..d0bd9beb
--- /dev/null
+++ b/tests/inheritance/test_multi_level_inheritance.py
@@ -0,0 +1,47 @@
+import sqlalchemy as sa
+from sqlalchemy_continuum import version_class
+from tests import TestCase
+
+
+class TestCommonBaseClass(TestCase):
+ def create_models(self):
+ class BaseModel(self.Model):
+ __tablename__ = 'base_model'
+ __versioned__ = {}
+
+ id = sa.Column(sa.Integer, primary_key=True)
+ discriminator = sa.Column(sa.String(50), index=True)
+
+ __mapper_args__ = {
+ 'polymorphic_on': discriminator,
+ 'polymorphic_identity': 'product'
+ }
+
+ class FirstLevel(BaseModel):
+ __tablename__ = 'first_level'
+
+ id = sa.Column(sa.Integer, sa.ForeignKey('base_model.id'), primary_key=True)
+
+ __mapper_args__ = {
+ 'polymorphic_identity': 'first_level'
+ }
+
+ class SecondLevel(FirstLevel):
+ __mapper_args__ = {
+ 'polymorphic_identity': 'second_level'
+ }
+
+ self.BaseModel = BaseModel
+ self.FirstLevel = FirstLevel
+ self.SecondLevel = SecondLevel
+
+ def test_sa_inheritance_with_no_distinct_table_has_right_translation_class(self):
+ class_ = version_class(self.BaseModel)
+ assert class_.__name__ == 'BaseModelVersion'
+ assert class_.__table__.name == 'base_model_version'
+ class_ = version_class(self.FirstLevel)
+ assert class_.__name__ == 'FirstLevelVersion'
+ assert class_.__table__.name == 'first_level_version'
+ class_ = version_class(self.SecondLevel)
+ assert class_.__name__ == 'SecondLevelVersion'
+ assert class_.__table__.name == 'first_level_version'
diff --git a/tests/inheritance/test_single_table_inheritance.py b/tests/inheritance/test_single_table_inheritance.py
index 311cd3f2..9b723c15 100644
--- a/tests/inheritance/test_single_table_inheritance.py
+++ b/tests/inheritance/test_single_table_inheritance.py
@@ -1,10 +1,9 @@
-from pytest import mark
import sqlalchemy as sa
from sqlalchemy_continuum import versioning_manager, version_class
-from tests import TestCase
+from tests import TestCase, create_test_cases
-class TestSingleTableInheritance(TestCase):
+class SingleTableInheritanceTestCase(TestCase):
def create_models(self):
class TextItem(self.Model):
__tablename__ = 'text_item'
@@ -19,6 +18,7 @@ class TextItem(self.Model):
__mapper_args__ = {
'polymorphic_on': discriminator,
+ 'with_polymorphic': '*'
}
class Article(TextItem):
@@ -39,24 +39,19 @@ def setup_method(self, method):
self.ArticleVersion = version_class(self.Article)
self.BlogPostVersion = version_class(self.BlogPost)
+ def test_inheritance(self):
+ assert issubclass(self.ArticleVersion, self.TextItemVersion)
+ assert issubclass(self.BlogPostVersion, self.TextItemVersion)
+
def test_version_class_map(self):
manager = self.TextItem.__versioning_manager__
assert len(manager.version_class_map.keys()) == 3
- def test_transaction_relations(self):
- tx_log = versioning_manager.transaction_cls
- assert tx_log.text_items
- assert tx_log.articles
- assert tx_log.blog_posts
-
def test_each_class_has_distinct_version_class(self):
assert self.TextItemVersion.__table__.name == 'text_item_version'
assert self.ArticleVersion.__table__.name == 'text_item_version'
assert self.BlogPostVersion.__table__.name == 'text_item_version'
- assert issubclass(self.ArticleVersion, self.TextItemVersion)
- assert issubclass(self.BlogPostVersion, self.TextItemVersion)
- @mark.skipif('True')
def test_each_object_has_distinct_version_class(self):
article = self.Article()
blogpost = self.BlogPost()
@@ -83,3 +78,6 @@ def test_transaction_changed_entities(self):
).first()
assert transaction.entity_names == [u'Article']
assert transaction.changed_entities
+
+
+create_test_cases(SingleTableInheritanceTestCase)
diff --git a/tests/plugins/test_activity.py b/tests/plugins/test_activity.py
index 50668b70..812eb542 100644
--- a/tests/plugins/test_activity.py
+++ b/tests/plugins/test_activity.py
@@ -1,7 +1,8 @@
+import pytest
import sqlalchemy as sa
from sqlalchemy_continuum import versioning_manager
from sqlalchemy_continuum.plugins import ActivityPlugin
-from tests import TestCase, QueryPool
+from tests import TestCase, QueryPool, uses_native_versioning
class ActivityTestCase(TestCase):
@@ -99,6 +100,7 @@ def test_activity_queries(self):
class TestObjectTxIdGeneration(ActivityTestCase):
+ @pytest.mark.skipif('uses_native_versioning()')
def test_does_not_query_db_if_version_obj_in_session(self):
article = self.create_article()
self.session.flush()
@@ -121,6 +123,7 @@ def test_create_activity_with_multiple_existing_objects(self):
class TestTargetTxIdGeneration(ActivityTestCase):
+ @pytest.mark.skipif('uses_native_versioning()')
def test_does_not_query_db_if_version_obj_in_session(self):
article = self.create_article()
self.session.flush()
diff --git a/tests/plugins/test_flask.py b/tests/plugins/test_flask.py
index 3771e805..b81d6a65 100644
--- a/tests/plugins/test_flask.py
+++ b/tests/plugins/test_flask.py
@@ -1,14 +1,40 @@
+import os
+
from flask import Flask, url_for
-from flask.ext.login import LoginManager
+from flask_login import LoginManager
+from flask_sqlalchemy import SQLAlchemy, _SessionSignalEvents
+from flexmock import flexmock
+
import sqlalchemy as sa
+from sqlalchemy_continuum import (
+ make_versioned, remove_versioning, versioning_manager
+)
from sqlalchemy_continuum.plugins import FlaskPlugin
from sqlalchemy_continuum.transaction import TransactionFactory
-from tests import TestCase
+from tests import (
+ TestCase,
+ get_driver_name,
+ get_dns_from_driver,
+ uses_native_versioning
+)
+
+class TestFlaskPluginConfiguration(object):
+ def test_set_factories(self):
+ some_func = lambda: None
+ some_other_func = lambda: None
+ plugin = FlaskPlugin(
+ current_user_id_factory=some_func,
+ remote_addr_factory=some_other_func
+ )
+ assert plugin.current_user_id_factory is some_func
+ assert plugin.remote_addr_factory is some_other_func
-class TestFlaskVersioningManager(TestCase):
+
+class TestFlaskPlugin(TestCase):
plugins = [FlaskPlugin()]
transaction_cls = TransactionFactory()
+ user_cls = 'User'
def setup_method(self, method):
TestCase.setup_method(self, method)
@@ -61,32 +87,56 @@ class User(self.Model):
self.User = User
def setup_views(self):
- @self.app.route('/')
- def index():
+ @self.app.route('/simple-flush')
+ def test_simple_flush():
article = self.Article()
article.name = u'Some article'
self.session.add(article)
self.session.commit()
return ''
+ @self.app.route('/raw-sql-and-flush')
+ def test_raw_sql_and_flush():
+ self.session.execute(
+ "INSERT INTO article (name) VALUES ('some article')"
+ )
+ article = self.Article()
+ article.name = u'Some article'
+ self.session.add(article)
+ self.session.flush()
+ self.session.execute(
+ "INSERT INTO article (name) VALUES ('some article')"
+ )
+ self.session.commit()
+ return ''
+
def test_versioning_inside_request(self):
user = self.User(name=u'Rambo')
self.session.add(user)
self.session.commit()
self.login(user)
- self.client.get(url_for('.index'))
+ self.client.get(url_for('.test_simple_flush'))
article = self.session.query(self.Article).first()
tx = article.versions[-1].transaction
assert tx.user.id == user.id
+ def test_raw_sql_and_flush(self):
+ user = self.User(name=u'Rambo')
+ self.session.add(user)
+ self.session.commit()
+ self.login(user)
+ self.client.get(url_for('.test_raw_sql_and_flush'))
+ assert (
+ self.session.query(versioning_manager.transaction_cls).count() == 2
+ )
+
-class TestFlaskVersioningManagerWithoutRequestContext(TestCase):
+class TestFlaskPluginWithoutRequestContext(TestCase):
plugins = [FlaskPlugin()]
+ user_cls = 'User'
def create_models(self):
- TestCase.create_models(self)
-
class User(self.Model):
__tablename__ = 'user'
__versioned__ = {
@@ -97,7 +147,139 @@ class User(self.Model):
name = sa.Column(sa.Unicode(255), nullable=False)
self.User = User
+ TestCase.create_models(self)
+
def test_versioning_outside_request(self):
user = self.User(name=u'Rambo')
self.session.add(user)
self.session.commit()
+
+
+class TestFlaskPluginWithFlaskSQLAlchemyExtension(object):
+ versioning_strategy = 'validity'
+
+ def create_models(self):
+ class User(self.db.Model):
+ __tablename__ = 'user'
+ __versioned__ = {
+ 'base_classes': (self.db.Model, )
+ }
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255), nullable=False)
+
+ class Article(self.db.Model):
+ __tablename__ = 'article'
+ __versioned__ = {
+ 'base_classes': (self.db.Model, )
+ }
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ article_tag = sa.Table(
+ 'article_tag',
+ self.db.Model.metadata,
+ sa.Column(
+ 'article_id',
+ sa.Integer,
+ sa.ForeignKey('article.id'),
+ primary_key=True,
+ ),
+ sa.Column(
+ 'tag_id',
+ sa.Integer,
+ sa.ForeignKey('tag.id'),
+ primary_key=True
+ )
+ )
+
+ class Tag(self.db.Model):
+ __tablename__ = 'tag'
+ __versioned__ = {
+ 'base_classes': (self.db.Model, )
+ }
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ Tag.articles = sa.orm.relationship(
+ Article,
+ secondary=article_tag,
+ backref='tags'
+ )
+
+ self.User = User
+ self.Article = Article
+ self.Tag = Tag
+
+ def setup_method(self, method):
+ # Mock the event registering of Flask-SQLAlchemy. Currently there is no
+ # way of unregistering Flask-SQLAlchemy event listeners, hence the
+ # event listeners would affect other tests.
+ flexmock(_SessionSignalEvents).should_receive('register')
+
+ self.db = SQLAlchemy()
+ make_versioned()
+
+ versioning_manager.transaction_cls = TransactionFactory()
+ versioning_manager.options['native_versioning'] = (
+ uses_native_versioning()
+ )
+
+ self.create_models()
+
+ sa.orm.configure_mappers()
+
+ self.app = Flask(__name__)
+ # self.app.config['SQLALCHEMY_ECHO'] = True
+ self.app.config['SQLALCHEMY_DATABASE_URI'] = get_dns_from_driver(
+ get_driver_name(os.environ.get('DB', 'sqlite'))
+ )
+ self.db.init_app(self.app)
+ self.app.secret_key = 'secret'
+ self.app.debug = True
+ self.client = self.app.test_client()
+ self.context = self.app.test_request_context()
+ self.context.push()
+ self.db.create_all()
+
+ def teardown_method(self, method):
+ remove_versioning()
+ self.db.session.remove()
+ self.db.drop_all()
+ self.db.session.close_all()
+ self.db.engine.dispose()
+ self.context.pop()
+ self.context = None
+ self.client = None
+ self.app = None
+
+ def test_version_relations(self):
+ article = self.Article()
+ article.name = u'Some article'
+ article.content = u'Some content'
+ self.db.session.add(article)
+ self.db.session.commit()
+ assert not article.versions[0].tags
+
+ def test_single_insert(self):
+ article = self.Article()
+ article.name = u'Some article'
+ article.content = u'Some content'
+ tag = self.Tag(name=u'some tag')
+ article.tags.append(tag)
+ self.db.session.add(article)
+ self.db.session.commit()
+ assert len(article.versions[0].tags) == 1
+
+ def test_create_transaction_with_scoped_session(self):
+ article = self.Article()
+ article.name = u'Some article'
+ article.content = u'Some content'
+ self.db.session.add(article)
+ uow = versioning_manager.unit_of_work(self.db.session)
+ transaction = uow.create_transaction(self.db.session)
+ assert transaction.id
+
+
diff --git a/tests/plugins/test_null_delete.py b/tests/plugins/test_null_delete.py
index 95087aa2..cd64ffc7 100644
--- a/tests/plugins/test_null_delete.py
+++ b/tests/plugins/test_null_delete.py
@@ -1,5 +1,6 @@
+import pytest
from sqlalchemy_continuum.plugins import NullDeletePlugin
-from tests import TestCase
+from tests import TestCase, uses_native_versioning
class DeleteTestCase(TestCase):
@@ -19,6 +20,7 @@ def test_stores_operation_type(self):
assert versions[1].operation_type == 2
+@pytest.mark.skipif('uses_native_versioning()')
class TestDeleteWithoutStoreDataAtDelete(DeleteTestCase):
plugins = [NullDeletePlugin()]
diff --git a/tests/plugins/test_property_mod_tracker.py b/tests/plugins/test_property_mod_tracker.py
index e434c286..061d1d68 100644
--- a/tests/plugins/test_property_mod_tracker.py
+++ b/tests/plugins/test_property_mod_tracker.py
@@ -32,13 +32,57 @@ def test_primary_keys_not_included(self):
UserVersion = version_class(self.User)
assert 'id_mod' not in UserVersion.__table__.c
- def test_mod_properties_get_updated(self):
+ def test_mod_properties_with_insert(self):
user = self.User(name=u'John')
self.session.add(user)
self.session.commit()
assert user.versions[-1].name_mod
+ def test_mod_properties_with_update(self):
+ user = self.User(name=u'John')
+ self.session.add(user)
+ self.session.commit()
+ user.age = 14
+ self.session.commit()
+ assert user.versions[-1].age_mod
+ assert not user.versions[-1].name_mod
+
+ def test_mod_properties_with_delete(self):
+ user = self.User(name=u'John')
+ self.session.add(user)
+ self.session.commit()
+ self.session.delete(user)
+ self.session.commit()
+ UserVersion = version_class(self.User)
+ version = (
+ self.session
+ .query(UserVersion)
+ .order_by(sa.desc(UserVersion.transaction_id))
+ ).first()
+ assert version.age_mod
+ assert version.name_mod
+
+ def test_consequtive_insert_and_update(self):
+ user = self.User(name=u'John')
+ self.session.add(user)
+ self.session.flush()
+ user.age = 15
+ self.session.commit()
+ assert user.versions[-1].age_mod
+ assert user.versions[-1].name_mod
+
+ def test_consequtive_update_and_update(self):
+ user = self.User(name=u'John')
+ self.session.add(user)
+ self.session.commit()
+ user.name = u'Jack'
+ self.session.flush()
+ user.age = 15
+ self.session.commit()
+ assert user.versions[-1].age_mod
+ assert user.versions[-1].name_mod
+
class TestChangeSetWithPropertyModPlugin(TestCase):
plugins = [PropertyModTrackerPlugin()]
@@ -70,3 +114,78 @@ def test_changeset_for_update(self):
'content': [u'Some content', u'Updated content'],
'name': [u'Some article', u'Updated name']
}
+
+
+class TestWithAssociationTables(TestCase):
+ plugins = [PropertyModTrackerPlugin()]
+
+ def create_models(self):
+ class Article(self.Model):
+ __tablename__ = 'article'
+ __versioned__ = {
+ 'base_classes': (self.Model, )
+ }
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ article_tag = sa.Table(
+ 'article_tag',
+ self.Model.metadata,
+ sa.Column(
+ 'article_id',
+ sa.Integer,
+ sa.ForeignKey('article.id'),
+ primary_key=True,
+ ),
+ sa.Column(
+ 'tag_id',
+ sa.Integer,
+ sa.ForeignKey('tag.id'),
+ primary_key=True
+ )
+ )
+
+ class Tag(self.Model):
+ __tablename__ = 'tag'
+ __versioned__ = {
+ 'base_classes': (self.Model, )
+ }
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ Tag.articles = sa.orm.relationship(
+ Article,
+ secondary=article_tag,
+ backref='tags'
+ )
+
+ self.Article = Article
+ self.Tag = Tag
+
+ def test_each_column_generates_additional_mod_column(self):
+ ArticleVersion = version_class(self.Article)
+ assert 'name_mod' in ArticleVersion.__table__.c
+ column = ArticleVersion.__table__.c['name_mod']
+ assert not column.nullable
+ assert isinstance(column.type, sa.Boolean)
+
+
+class TestModTrackingWithRelationships(TestCase):
+ plugins = [PropertyModTrackerPlugin()]
+
+ def test_with_insert(self):
+ tag = self.Tag(article=self.Article(name=u'Some article'))
+ self.session.add(tag)
+ self.session.commit()
+ assert tag.versions[-1]
+
+ def test_with_update(self):
+ tag = self.Tag(article=self.Article(name=u'Some article'))
+ self.session.add(tag)
+ self.session.commit()
+ tag.article = None
+ self.session.commit()
+
+ assert tag.versions[-1].article_id_mod
diff --git a/tests/plugins/test_transaction_changes.py b/tests/plugins/test_transaction_changes.py
index 88a53eef..528b31fa 100644
--- a/tests/plugins/test_transaction_changes.py
+++ b/tests/plugins/test_transaction_changes.py
@@ -58,7 +58,6 @@ def test_saves_changed_entity_names(self):
tx = article.versions[0].transaction
assert tx.changes[0].entity_name == u'Article'
- assert article.versions[0].changes[0] == tx.changes[0]
def test_saves_only_modified_entity_names(self):
article = self.Article()
diff --git a/tests/relationships/test_association_table_relations.py b/tests/relationships/test_association_table_relations.py
new file mode 100644
index 00000000..81ec3739
--- /dev/null
+++ b/tests/relationships/test_association_table_relations.py
@@ -0,0 +1,61 @@
+import sqlalchemy as sa
+from sqlalchemy import PrimaryKeyConstraint
+from sqlalchemy.orm import relationship
+from tests import TestCase, create_test_cases
+
+
+class AssociationTableRelationshipsTestCase(TestCase):
+ def create_models(self):
+ super(AssociationTableRelationshipsTestCase, self).create_models()
+
+ class PublishedArticle(self.Model):
+ __tablename__ = 'published_article'
+ __table_args__ = (
+ PrimaryKeyConstraint("article_id", "author_id"),
+ {'useexisting': True}
+ )
+
+ article_id = sa.Column(sa.Integer, sa.ForeignKey('article.id'))
+ author_id = sa.Column(sa.Integer, sa.ForeignKey('author.id'))
+ author = relationship('Author')
+ article = relationship('Article')
+
+ self.PublishedArticle = PublishedArticle
+
+ published_articles_table = sa.Table(PublishedArticle.__tablename__,
+ PublishedArticle.metadata,
+ extend_existing=True)
+
+ class Author(self.Model):
+ __tablename__ = 'author'
+ __versioned__ = {
+ 'base_classes': (self.Model, )
+ }
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+ articles = relationship('Article', secondary=published_articles_table)
+
+ self.Author = Author
+
+ def test_version_relations(self):
+ article = self.Article()
+ name = u'Some article'
+ article.name = name
+ article.content = u'Some content'
+ self.session.add(article)
+ self.session.commit()
+ assert article.versions[0].name == name
+
+ au = self.Author(name=u'Some author')
+ self.session.add(au)
+ self.session.commit()
+
+ pa = self.PublishedArticle(article_id=article.id, author_id=au.id)
+ self.session.add(pa)
+
+ self.session.commit()
+
+
+
+create_test_cases(AssociationTableRelationshipsTestCase)
diff --git a/tests/relationships/test_custom_condition_relations.py b/tests/relationships/test_custom_condition_relations.py
index d856da86..b888e424 100644
--- a/tests/relationships/test_custom_condition_relations.py
+++ b/tests/relationships/test_custom_condition_relations.py
@@ -30,7 +30,7 @@ class Tag(self.Model):
Tag,
primaryjoin=sa.and_(
Tag.article_id == Article.id,
- Tag.category == 'primary'
+ Tag.category == u'primary'
),
)
@@ -38,7 +38,7 @@ class Tag(self.Model):
Tag,
primaryjoin=sa.and_(
Tag.article_id == Article.id,
- Tag.category == 'secondary'
+ Tag.category == u'secondary'
),
)
diff --git a/tests/relationships/test_many_to_many_relations.py b/tests/relationships/test_many_to_many_relations.py
index 497619cd..2aaa1196 100644
--- a/tests/relationships/test_many_to_many_relations.py
+++ b/tests/relationships/test_many_to_many_relations.py
@@ -1,4 +1,8 @@
+import pytest
+from pytest import mark
import sqlalchemy as sa
+from sqlalchemy_continuum import versioning_manager
+
from tests import TestCase, create_test_cases
@@ -77,6 +81,18 @@ def test_multi_insert(self):
self.session.commit()
assert len(article.versions[0].tags) == 2
+ def test_collection_with_multiple_entries(self):
+ article = self.Article()
+ article.name = u'Some article'
+ article.content = u'Some content'
+ self.session.add(article)
+ article.tags = [
+ self.Tag(name=u'some tag'),
+ self.Tag(name=u'another tag')
+ ]
+ self.session.commit()
+ assert len(article.versions[0].tags) == 2
+
def test_delete_single_association(self):
article = self.Article()
article.name = u'Some article'
@@ -137,5 +153,297 @@ def test_multiple_parent_objects_added_within_same_transaction(self):
tags = article.versions[0].tags
assert tags == [tag.versions[0]]
+ def test_relations_with_varying_transactions(self):
+ if (
+ self.driver == 'mysql' and
+ self.connection.dialect.server_version_info < (5, 6)
+ ):
+ pytest.skip()
+
+ # one article with one tag
+ article = self.Article(name=u'Some article')
+ tag1 = self.Tag(name=u'some tag')
+ article.tags.append(tag1)
+ self.session.add(article)
+ self.session.commit()
+
+ # update article and tag, add a 2nd tag
+ tag2 = self.Tag(name=u'some other tag')
+ article.tags.append(tag2)
+ tag1.name = u'updated tag1'
+ article.name = u'updated article'
+ self.session.commit()
+
+ # update article and first tag only
+ tag1.name = u'updated tag1 x2'
+ article.name = u'updated article x2'
+ self.session.commit()
+
+ assert len(article.versions[0].tags) == 1
+ assert article.versions[0].tags[0] is tag1.versions[0]
+
+ assert len(article.versions[1].tags) == 2
+ assert tag1.versions[1] in article.versions[1].tags
+ assert tag2.versions[0] in article.versions[1].tags
+
+ assert len(article.versions[2].tags) == 2
+ assert tag1.versions[2] in article.versions[2].tags
+ assert tag2.versions[0] in article.versions[2].tags
+
create_test_cases(ManyToManyRelationshipsTestCase)
+
+
+class TestManyToManyRelationshipWithViewOnly(TestCase):
+ def create_models(self):
+ class Article(self.Model):
+ __tablename__ = 'article'
+ __versioned__ = {
+ 'base_classes': (self.Model, )
+ }
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ article_tag = sa.Table(
+ 'article_tag',
+ self.Model.metadata,
+ sa.Column(
+ 'article_id',
+ sa.Integer,
+ sa.ForeignKey('article.id'),
+ primary_key=True,
+ ),
+ sa.Column(
+ 'tag_id',
+ sa.Integer,
+ sa.ForeignKey('tag.id'),
+ primary_key=True
+ )
+ )
+
+ class Tag(self.Model):
+ __tablename__ = 'tag'
+ __versioned__ = {
+ 'base_classes': (self.Model, )
+ }
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ Tag.articles = sa.orm.relationship(
+ Article,
+ secondary=article_tag,
+ viewonly=True
+ )
+
+ self.article_tag = article_tag
+ self.Article = Article
+ self.Tag = Tag
+
+ def test_does_not_add_association_table_to_manager_registry(self):
+ assert self.article_tag not in versioning_manager.association_tables
+
+
+class TestManyToManySelfReferential(TestCase):
+
+ def create_models(self):
+ class Article(self.Model):
+ __tablename__ = 'article'
+ __versioned__ = {}
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ article_references = sa.Table(
+ 'article_references',
+ self.Model.metadata,
+ sa.Column(
+ 'referring_id',
+ sa.Integer,
+ sa.ForeignKey('article.id'),
+ primary_key=True,
+ ),
+ sa.Column(
+ 'referred_id',
+ sa.Integer,
+ sa.ForeignKey('article.id'),
+ primary_key=True
+ )
+ )
+
+ Article.references = sa.orm.relationship(
+ Article,
+ secondary=article_references,
+ primaryjoin=Article.id == article_references.c.referring_id,
+ secondaryjoin=Article.id == article_references.c.referred_id,
+ backref='cited_by'
+ )
+
+ self.Article = Article
+ self.referenced_articles_table = article_references
+
+
+ def test_single_insert(self):
+
+ article = self.Article(name=u'article')
+ reference1 = self.Article(name=u'referred article 1')
+ article.references.append(reference1)
+ self.session.add(article)
+ self.session.commit()
+
+ assert len(article.versions[0].references) == 1
+ assert reference1.versions[0] in article.versions[0].references
+
+ assert len(reference1.versions[0].cited_by) == 1
+ assert article.versions[0] in reference1.versions[0].cited_by
+
+
+ def test_multiple_inserts_over_multiple_transactions(self):
+ if (
+ self.driver == 'mysql' and
+ self.connection.dialect.server_version_info < (5, 6)
+ ):
+ pytest.skip()
+
+ # create 1 article with 1 reference
+ article = self.Article(name=u'article')
+ reference1 = self.Article(name=u'reference 1')
+ article.references.append(reference1)
+ self.session.add(article)
+ self.session.commit()
+
+ # update existing, add a 2nd reference
+ article.name = u'Updated article'
+ reference1.name = u'Updated reference 1'
+ reference2 = self.Article(name=u'reference 2')
+ article.references.append(reference2)
+ self.session.commit()
+
+ # update only the article and reference 1
+ article.name = u'Updated article x2'
+ reference1.name = u'Updated reference 1 x2'
+ self.session.commit()
+
+ assert len(article.versions[1].references) == 2
+ assert reference1.versions[1] in article.versions[1].references
+ assert reference2.versions[0] in article.versions[1].references
+
+ assert len(reference1.versions[1].cited_by) == 1
+ assert article.versions[1] in reference1.versions[1].cited_by
+
+ assert len(reference2.versions[0].cited_by) == 1
+ assert article.versions[1] in reference2.versions[0].cited_by
+
+ assert len(article.versions[2].references) == 2
+ assert reference1.versions[2] in article.versions[2].references
+ assert reference2.versions[0] in article.versions[2].references
+
+ assert len(reference1.versions[2].cited_by) == 1
+ assert article.versions[2] in reference1.versions[2].cited_by
+
+
+@mark.skipif("os.environ.get('DB') == 'sqlite'")
+class TestManyToManySelfReferentialInOtherSchema(TestManyToManySelfReferential):
+ def create_models(self):
+ class Article(self.Model):
+ __tablename__ = 'article'
+ __versioned__ = {}
+ __table_args__ = {'schema': 'other'}
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ article_references = sa.Table(
+ 'article_references',
+ self.Model.metadata,
+ sa.Column(
+ 'referring_id',
+ sa.Integer,
+ sa.ForeignKey('other.article.id'),
+ primary_key=True,
+ ),
+ sa.Column(
+ 'referred_id',
+ sa.Integer,
+ sa.ForeignKey('other.article.id'),
+ primary_key=True
+ ),
+ schema='other'
+ )
+
+ Article.references = sa.orm.relationship(
+ Article,
+ secondary=article_references,
+ primaryjoin=Article.id == article_references.c.referring_id,
+ secondaryjoin=Article.id == article_references.c.referred_id,
+ backref='cited_by'
+ )
+
+ self.Article = Article
+ self.referenced_articles_table = article_references
+
+ def create_tables(self):
+ self.connection.execute('DROP SCHEMA IF EXISTS other')
+ self.connection.execute('CREATE SCHEMA other')
+ TestManyToManySelfReferential.create_tables(self)
+
+
+@mark.skipif("os.environ.get('DB') == 'sqlite'")
+class ManyToManyRelationshipsInOtherSchemaTestCase(ManyToManyRelationshipsTestCase):
+ def create_models(self):
+ class Article(self.Model):
+ __tablename__ = 'article'
+ __versioned__ = {
+ 'base_classes': (self.Model, )
+ }
+ __table_args__ = {'schema': 'other'}
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ article_tag = sa.Table(
+ 'article_tag',
+ self.Model.metadata,
+ sa.Column(
+ 'article_id',
+ sa.Integer,
+ sa.ForeignKey('other.article.id'),
+ primary_key=True,
+ ),
+ sa.Column(
+ 'tag_id',
+ sa.Integer,
+ sa.ForeignKey('other.tag.id'),
+ primary_key=True
+ ),
+ schema='other'
+ )
+
+ class Tag(self.Model):
+ __tablename__ = 'tag'
+ __versioned__ = {
+ 'base_classes': (self.Model, )
+ }
+ __table_args__ = {'schema': 'other'}
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ Tag.articles = sa.orm.relationship(
+ Article,
+ secondary=article_tag,
+ backref='tags'
+ )
+
+ self.Article = Article
+ self.Tag = Tag
+
+
+ def create_tables(self):
+ self.connection.execute('DROP SCHEMA IF EXISTS other')
+ self.connection.execute('CREATE SCHEMA other')
+ ManyToManyRelationshipsTestCase.create_tables(self)
+
+create_test_cases(ManyToManyRelationshipsInOtherSchemaTestCase)
+
diff --git a/tests/relationships/test_non_versioned_classes.py b/tests/relationships/test_non_versioned_classes.py
new file mode 100644
index 00000000..cf2dad53
--- /dev/null
+++ b/tests/relationships/test_non_versioned_classes.py
@@ -0,0 +1,106 @@
+from copy import copy
+from tests import TestCase
+import sqlalchemy as sa
+
+
+class TestRelationshipToNonVersionedClass(TestCase):
+ def create_models(self):
+ class User(self.Model):
+ __tablename__ = 'user'
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ class Article(self.Model):
+ __tablename__ = 'article'
+ __versioned__ = copy(self.options)
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255), nullable=False)
+ content = sa.Column(sa.UnicodeText)
+ description = sa.Column(sa.UnicodeText)
+ author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
+ author = sa.orm.relationship(User)
+
+ self.Article = Article
+ self.User = User
+
+ def test_single_insert(self):
+ article = self.Article()
+ article.name = u'Some article'
+ article.content = u'Some content'
+ user = self.User(name=u'Some user')
+ article.author = user
+ self.session.add(article)
+ self.session.commit()
+
+ assert isinstance(article.versions[0].author, self.User)
+
+ def test_change_relationship(self):
+ article = self.Article()
+ article.name = u'Some article'
+ article.content = u'Some content'
+ user = self.User(name=u'Some user')
+ self.session.add(article)
+ self.session.add(user)
+ self.session.commit()
+
+ assert article.versions.count() == 1
+ article.author = user
+ self.session.commit()
+ assert article.versions.count() == 2
+
+
+class TestManyToManyRelationshipToNonVersionedClass(TestCase):
+ def create_models(self):
+ class Article(self.Model):
+ __tablename__ = 'article'
+ __versioned__ = {
+ 'base_classes': (self.Model, )
+ }
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ article_tag = sa.Table(
+ 'article_tag',
+ self.Model.metadata,
+ sa.Column(
+ 'article_id',
+ sa.Integer,
+ sa.ForeignKey('article.id'),
+ primary_key=True,
+ ),
+ sa.Column(
+ 'tag_id',
+ sa.Integer,
+ sa.ForeignKey('tag.id'),
+ primary_key=True
+ )
+ )
+
+ class Tag(self.Model):
+ __tablename__ = 'tag'
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+
+ Tag.articles = sa.orm.relationship(
+ Article,
+ secondary=article_tag,
+ backref='tags'
+ )
+
+ self.Article = Article
+ self.Tag = Tag
+
+ def test_single_insert(self):
+ article = self.Article()
+ article.name = u'Some article'
+ article.content = u'Some content'
+ tag = self.Tag(name=u'some tag')
+ article.tags.append(tag)
+ self.session.add(article)
+ self.session.commit()
+ assert len(article.versions[0].tags) == 1
+ assert isinstance(article.versions[0].tags[0], self.Tag)
diff --git a/tests/relationships/test_one_to_many_relations.py b/tests/relationships/test_one_to_many_relations.py
index ab5764f1..3e07e3d7 100644
--- a/tests/relationships/test_one_to_many_relations.py
+++ b/tests/relationships/test_one_to_many_relations.py
@@ -1,3 +1,4 @@
+import pytest
import sqlalchemy as sa
from tests import TestCase, create_test_cases
@@ -70,6 +71,45 @@ def test_multiple_inserts_in_consecutive_transactions(self):
assert len(article.versions[0].tags) == 1
assert len(article.versions[1].tags) == 2
+ def test_children_inserts_with_varying_versions(self):
+ if (
+ self.driver == 'mysql' and
+ self.connection.dialect.server_version_info < (5, 6)
+ ):
+ pytest.skip()
+
+ # one article with one tag
+ article = self.Article()
+ article.name = u'Some article'
+ article.content = u'Some content'
+ tag = self.Tag(name=u'some tag')
+ article.tags.append(tag)
+ self.session.add(article)
+ self.session.commit()
+
+ # update the article and the tag, and add a 2nd tag
+ article.name = u'Updated article'
+ tag.name = u'updated tag'
+ tag2 = self.Tag(name=u'other tag',
+ article=article)
+ self.session.commit()
+
+ # update the article and the tag again
+ article.name = u'Updated again article'
+ tag.name = u'updated again tag'
+ self.session.commit()
+
+ assert len(article.versions[0].tags) == 1
+ assert article.versions[0].tags[0] is tag.versions[0]
+
+ assert len(article.versions[1].tags) == 2
+ assert tag.versions[1] in article.versions[1].tags
+ assert tag2.versions[0] in article.versions[1].tags
+
+ assert len(article.versions[2].tags) == 2
+ assert tag.versions[2] in article.versions[2].tags
+ assert tag2.versions[0] in article.versions[2].tags
+
def test_delete(self):
article = self.Article()
article.name = u'Some article'
@@ -126,3 +166,90 @@ def test_single_insert(self):
self.session.add(article)
self.session.commit()
assert article.versions[0].category == category.versions[0]
+
+
+class TestOneToManySelfReferential(TestCase):
+ def create_models(self):
+ class Article(self.Model):
+ __tablename__ = 'article'
+ __versioned__ = {}
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255), nullable=False)
+ content = sa.Column(sa.UnicodeText)
+ description = sa.Column(sa.UnicodeText)
+
+ parent_article_id = sa.Column(sa.ForeignKey(id))
+ parent_article = sa.orm.relationship("Article",
+ remote_side=[id],
+ backref="child_articles")
+
+ self.Article = Article
+
+ def test_single_insert(self):
+ parent_article = self.Article(name=u'Some article')
+ child_article1 = self.Article(name=u'Child article1',
+ parent_article=parent_article)
+ self.session.add(parent_article)
+ self.session.commit()
+
+ assert len(parent_article.versions[0].child_articles) == 1
+ assert (
+ child_article1.versions[0] in
+ parent_article.versions[0].child_articles
+ )
+ assert (
+ child_article1.versions[0].parent_article is
+ parent_article.versions[0]
+ )
+
+ def test_multiple_inserts_over_multiple_transactions(self):
+ if (
+ self.driver == 'mysql' and
+ self.connection.dialect.server_version_info < (5, 6)
+ ):
+ pytest.skip()
+ parent_article = self.Article(name=u'Some article')
+ child_article1 = self.Article(name=u'Child article1',
+ parent_article=parent_article)
+ self.session.add(parent_article)
+ self.session.commit()
+
+ # update articles, add a 2nd child
+ parent_article.name = u'Updated article'
+ child_article1.name = u'Updated child article1'
+ child_article2 = self.Article(name=u'Child article2',
+ parent_article=parent_article)
+ self.session.commit()
+ # update the parent and 1st child
+ parent_article.name = u'Updated article x2'
+ child_article1.name = u'Updated child article1 x2'
+ self.session.commit()
+
+ assert len(parent_article.versions[1].child_articles) == 2
+ assert (
+ child_article1.versions[1] in
+ parent_article.versions[1].child_articles
+ )
+ assert (
+ child_article2.versions[0] in
+ parent_article.versions[1].child_articles
+ )
+ assert (
+ child_article1.versions[1].parent_article is
+ parent_article.versions[1]
+ )
+
+ assert len(parent_article.versions[2].child_articles) == 2
+ assert (
+ child_article1.versions[2] in
+ parent_article.versions[2].child_articles
+ )
+ assert (
+ child_article2.versions[0] in
+ parent_article.versions[2].child_articles
+ )
+ assert (
+ child_article1.versions[2].parent_article is
+ parent_article.versions[2]
+ )
diff --git a/tests/relationships/test_one_to_one_relations.py b/tests/relationships/test_one_to_one_relations.py
index bea65c4a..0a5b306a 100644
--- a/tests/relationships/test_one_to_one_relations.py
+++ b/tests/relationships/test_one_to_one_relations.py
@@ -50,21 +50,6 @@ def test_multiple_relation_versions(self):
assert article.versions[0].author == user.versions[0]
- def test_multiple_parent_and_relation_versions(self):
- article = self.Article()
- article.name = u'Some article'
- article.content = u'Some content'
- user = self.User(name=u'Some user')
- article.author = user
- self.session.add(article)
- self.session.commit()
- user.name = u'Someone else'
- self.session.commit()
-
- article.name = u'Updated article'
-
- assert article.versions[1].author == user.versions[1]
-
def test_multiple_consecutive_inserts_and_removes(self):
article = self.Article()
article.name = u'Some article'
diff --git a/tests/schema/test_update_end_transaction_id.py b/tests/schema/test_update_end_transaction_id.py
index 52d71305..0a3de188 100644
--- a/tests/schema/test_update_end_transaction_id.py
+++ b/tests/schema/test_update_end_transaction_id.py
@@ -6,41 +6,53 @@
class TestSchemaTools(TestCase):
versioning_strategy = 'validity'
- def test_something(self):
+ def _insert(self, values):
table = version_class(self.Article).__table__
- stmt = table.insert().values([
+ stmt = table.insert().values(values)
+ self.session.execute(stmt)
+
+ def test_update_end_transaction_id(self):
+ table = version_class(self.Article).__table__
+ self._insert(
{
'id': 1,
'transaction_id': 1,
'name': u'Article 1',
'operation_type': 1,
- },
+ }
+ )
+ self._insert(
{
'id': 1,
'transaction_id': 2,
'name': u'Article 1 updated',
'operation_type': 2,
- },
+ }
+ )
+ self._insert(
{
'id': 2,
'transaction_id': 3,
'name': u'Article 2',
'operation_type': 1,
- },
+ }
+ )
+ self._insert(
{
'id': 1,
'transaction_id': 4,
'name': u'Article 1 updated (again)',
'operation_type': 2,
- },
+ }
+ )
+ self._insert(
{
'id': 2,
'transaction_id': 5,
'name': u'Article 2 updated',
'operation_type': 2,
- },
- ])
- self.session.execute(stmt)
+ }
+ )
update_end_tx_column(table, conn=self.session)
rows = self.session.execute(
diff --git a/tests/schema/test_update_property_mod_flags.py b/tests/schema/test_update_property_mod_flags.py
index a2adac60..50d705ce 100644
--- a/tests/schema/test_update_property_mod_flags.py
+++ b/tests/schema/test_update_property_mod_flags.py
@@ -21,9 +21,14 @@ class Article(self.Model):
self.Article = Article
+ def _insert(self, values):
+ table = version_class(self.Article).__table__
+ stmt = table.insert().values(values)
+ self.session.execute(stmt)
+
def test_something(self):
table = version_class(self.Article).__table__
- stmt = table.insert().values([
+ self._insert(
{
'id': 1,
'transaction_id': 1,
@@ -31,7 +36,9 @@ def test_something(self):
'name': u'Article 1',
'name_mod': False,
'operation_type': 1,
- },
+ }
+ )
+ self._insert(
{
'id': 1,
'transaction_id': 2,
@@ -39,7 +46,9 @@ def test_something(self):
'name': u'Article 1',
'name_mod': False,
'operation_type': 2,
- },
+ }
+ )
+ self._insert(
{
'id': 2,
'transaction_id': 3,
@@ -47,7 +56,9 @@ def test_something(self):
'name': u'Article 2',
'name_mod': False,
'operation_type': 1,
- },
+ }
+ )
+ self._insert(
{
'id': 1,
'transaction_id': 4,
@@ -55,7 +66,9 @@ def test_something(self):
'name': u'Article 1 updated',
'name_mod': False,
'operation_type': 2,
- },
+ }
+ )
+ self._insert(
{
'id': 2,
'transaction_id': 5,
@@ -63,9 +76,8 @@ def test_something(self):
'name': u'Article 2',
'name_mod': False,
'operation_type': 2,
- },
- ])
- self.session.execute(stmt)
+ }
+ )
update_property_mod_flags(
table,
diff --git a/tests/test_changeset.py b/tests/test_changeset.py
index e8353348..757ff3b7 100644
--- a/tests/test_changeset.py
+++ b/tests/test_changeset.py
@@ -37,9 +37,16 @@ class ChangeSetTestCase(ChangeSetBaseTestCase):
def test_changeset_for_history_that_does_not_have_first_insert(self):
tx_log_class = get_versioning_manager(self.Article).transaction_cls
tx_log = tx_log_class(issued_at=sa.func.now())
+ if self.options['native_versioning']:
+ tx_log.id = sa.func.txid_current()
+
self.session.add(tx_log)
self.session.commit()
+ # Needed when using native versioning
+ self.session.expunge_all()
+ tx_log = self.session.query(tx_log_class).first()
+
self.session.execute(
'''INSERT INTO article_version
(id, %s, name, content, operation_type)
@@ -48,7 +55,11 @@ def test_changeset_for_history_that_does_not_have_first_insert(self):
''' % (self.transaction_column_name, tx_log.id)
)
- assert self.session.query(self.ArticleVersion).first().changeset == {}
+ assert self.session.query(self.ArticleVersion).first().changeset == {
+ 'content': [None, 'some content'],
+ 'id': [None, 1],
+ 'name': [None, 'something']
+ }
class TestChangeSetWithValidityStrategy(ChangeSetTestCase):
@@ -64,7 +75,7 @@ def create_models(self):
class Article(self.Model):
__tablename__ = 'article'
__versioned__ = {
- 'base_classes': (self.Model, )
+ 'base_classes': (self.Model,)
}
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
@@ -75,7 +86,7 @@ class Article(self.Model):
class Tag(self.Model):
__tablename__ = 'tag'
__versioned__ = {
- 'base_classes': (self.Model, )
+ 'base_classes': (self.Model,)
}
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
diff --git a/tests/test_column_aliases.py b/tests/test_column_aliases.py
index e0be1c46..53af4fa9 100644
--- a/tests/test_column_aliases.py
+++ b/tests/test_column_aliases.py
@@ -1,20 +1,32 @@
+from pytest import mark
import sqlalchemy as sa
+from sqlalchemy_continuum import version_class
from tests import TestCase, create_test_cases
-class ColumnAliasesTestCase(TestCase):
+class ColumnAliasesBaseTestCase(TestCase):
def create_models(self):
class TextItem(self.Model):
__tablename__ = 'text_item'
__versioned__ = {}
- id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ id = sa.Column(
+ '_id', sa.Integer, autoincrement=True, primary_key=True
+ )
name = sa.Column('_name', sa.Unicode(255))
self.TextItem = TextItem
+
+@mark.skipif('True')
+class TestVersionTableWithColumnAliases(ColumnAliasesBaseTestCase):
+ def test_column_reflection(self):
+ assert '_id' in version_class(self.TextItem).__table__.c
+
+
+class ColumnAliasesTestCase(ColumnAliasesBaseTestCase):
def test_insert(self):
item = self.TextItem(name=u'Something')
self.session.add(item)
@@ -30,5 +42,26 @@ def test_revert(self):
item.versions[0].revert()
self.session.commit()
+ def test_previous_for_deleted_parent(self):
+ item = self.TextItem()
+ item.name = u'Some item'
+ item.content = u'Some content'
+ self.session.add(item)
+ self.session.commit()
+ self.session.delete(item)
+ self.session.commit()
+ TextItemVersion = version_class(self.TextItem)
+
+ versions = (
+ self.session.query(TextItemVersion)
+ .order_by(
+ getattr(
+ TextItemVersion,
+ self.options['transaction_column_name']
+ )
+ )
+ ).all()
+ assert versions[1].previous.name == u'Some item'
+
create_test_cases(ColumnAliasesTestCase)
diff --git a/tests/test_column_inclusion_and_exclusion.py b/tests/test_column_inclusion_and_exclusion.py
index 641f2b44..e916b383 100644
--- a/tests/test_column_inclusion_and_exclusion.py
+++ b/tests/test_column_inclusion_and_exclusion.py
@@ -1,86 +1,74 @@
-from datetime import datetime
-from pytest import mark
import sqlalchemy as sa
-from sqlalchemy_utils import TSVectorType
from sqlalchemy_continuum import version_class
from tests import TestCase
-class TestDateTimeColumnExclusion(TestCase):
- def create_models(self):
- class Article(self.Model):
- __tablename__ = 'article'
- __versioned__ = {}
- id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
- name = sa.Column(sa.Unicode(255))
- created_at = sa.Column(sa.DateTime, default=datetime.now)
- creation_date = sa.Column(
- sa.Date, default=lambda: datetime.now().date
- )
- is_deleted = sa.Column(sa.Boolean, default=False)
-
- self.Article = Article
-
- def test_datetime_columns_with_defaults_excluded_by_default(self):
- assert (
- 'created_at' not in
- version_class(self.Article).__table__.c
- )
-
- def test_date_columns_with_defaults_excluded_by_default(self):
- assert (
- 'creation_date' not in
- version_class(self.Article).__table__.c
- )
-
- def test_datetime_exclusion_only_applies_to_datetime_types(self):
- assert (
- 'is_deleted' in
- version_class(self.Article).__table__.c
- )
-
-
-@mark.skipif("os.environ.get('DB') != 'postgres'")
-class TestTSVectorTypeColumnExclusion(TestCase):
+class ColumnExclusionTestCase(TestCase):
+ def test_excluded_columns_not_included_in_version_class(self):
+ cls = version_class(self.TextItem)
+ manager = cls._sa_class_manager
+ assert 'content' not in manager.keys()
+
+ def test_versioning_with_column_exclusion(self):
+ item = self.TextItem(name=u'Some textitem', content=u'Some content')
+ self.session.add(item)
+ self.session.commit()
+
+ assert item.versions[0].name == u'Some textitem'
+
+ def test_does_not_create_record_if_only_excluded_column_updated(self):
+ item = self.TextItem(name=u'Some textitem')
+ self.session.add(item)
+ self.session.commit()
+ item.content = u'Some content'
+ self.session.commit()
+ assert item.versions.count() == 1
+
+
+class TestColumnExclusion(ColumnExclusionTestCase):
def create_models(self):
- class Article(self.Model):
- __tablename__ = 'article'
- __versioned__ = {}
+ class TextItem(self.Model):
+ __tablename__ = 'text_item'
+ __versioned__ = {
+ 'exclude': ['content']
+ }
+
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
- search_vector = sa.Column(TSVectorType)
-
- self.Article = Article
+ content = sa.Column(sa.UnicodeText)
- def test_tsvector_typed_columns_excluded_by_default(self):
- assert (
- 'search_vector' not in
- version_class(self.Article).__table__.c
- )
+ self.TextItem = TextItem
-class TestDateTimeColumnInclusion(TestCase):
+class TestColumnExclusionWithAliasedColumn(ColumnExclusionTestCase):
def create_models(self):
- class Article(self.Model):
- __tablename__ = 'article'
+ class TextItem(self.Model):
+ __tablename__ = 'text_item'
__versioned__ = {
- 'include': 'created_at'
+ 'exclude': ['content']
}
+
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
- created_at = sa.Column(sa.DateTime, default=datetime.now)
-
- self.Article = Article
+ content = sa.Column('_content', sa.UnicodeText)
- def test_datetime_columns_with_defaults_excluded_by_default(self):
- assert (
- 'created_at' in
- version_class(self.Article).__table__.c
- )
+ self.TextItem = TextItem
-class TestColumnExclusion(TestCase):
+class TestColumnExclusionWithRelationship(TestCase):
def create_models(self):
+
+ class Word(self.Model):
+ __tablename__ = 'word'
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ word = sa.Column(sa.Unicode(255))
+
+ class TextItemWord(self.Model):
+ __tablename__ = 'text_item_word'
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ text_item_id = sa.Column(sa.Integer, sa.ForeignKey('text_item.id'), nullable=False)
+ word_id = sa.Column(sa.Integer, sa.ForeignKey('word.id'), nullable=False)
+
class TextItem(self.Model):
__tablename__ = 'text_item'
__versioned__ = {
@@ -89,9 +77,10 @@ class TextItem(self.Model):
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
- content = sa.Column(sa.UnicodeText)
+ content = sa.orm.relationship(Word, secondary='text_item_word')
self.TextItem = TextItem
+ self.Word = Word
def test_excluded_columns_not_included_in_version_class(self):
cls = version_class(self.TextItem)
@@ -99,8 +88,17 @@ def test_excluded_columns_not_included_in_version_class(self):
assert 'content' not in manager.keys()
def test_versioning_with_column_exclusion(self):
- item = self.TextItem(name=u'Some textitem', content=u'Some content')
+ item = self.TextItem(name=u'Some textitem',
+ content=[self.Word(word=u'bird')])
self.session.add(item)
self.session.commit()
assert item.versions[0].name == u'Some textitem'
+
+ def test_does_not_create_record_if_only_excluded_column_updated(self):
+ item = self.TextItem(name=u'Some textitem')
+ self.session.add(item)
+ self.session.commit()
+ item.content.append(self.Word(word=u'Some content'))
+ self.session.commit()
+ assert item.versions.count() == 1
diff --git a/tests/test_configuration.py b/tests/test_configuration.py
index bae5aaf9..9dd034cc 100644
--- a/tests/test_configuration.py
+++ b/tests/test_configuration.py
@@ -1,4 +1,10 @@
+from pytest import raises, skip
import sqlalchemy as sa
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy_continuum import (
+ versioning_manager, ImproperlyConfigured, TransactionFactory
+)
+
from tests import TestCase
@@ -23,3 +29,91 @@ def test_does_not_create_history_table(self):
def test_does_add_objects_to_unit_of_work(self):
self.session.add(self.TextItem())
self.session.commit()
+
+
+class TestWithUnknownUserClass(object):
+ def test_raises_improperly_configured_error(self):
+ self.Model = declarative_base()
+
+ class TextItem(self.Model):
+ __tablename__ = 'text_item'
+ __versioned__ = {}
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+
+ self.TextItem = TextItem
+
+ versioning_manager.user_cls = 'User'
+ versioning_manager.declarative_base = self.Model
+
+ factory = TransactionFactory()
+ with raises(ImproperlyConfigured):
+ factory(versioning_manager)
+
+
+class TestWithCreateModelsAsFalse(TestCase):
+ should_create_models = False
+
+ def create_models(self):
+ class Article(self.Model):
+ __tablename__ = 'article'
+ __versioned__ = {}
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255), nullable=False)
+ content = sa.Column(sa.UnicodeText)
+ description = sa.Column(sa.UnicodeText)
+
+ class Category(self.Model):
+ __tablename__ = 'category'
+ __versioned__ = {}
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+ article_id = sa.Column(sa.Integer, sa.ForeignKey(Article.id))
+ article = sa.orm.relationship(
+ Article,
+ backref=sa.orm.backref(
+ 'category',
+ uselist=False
+ )
+ )
+
+ self.Article = Article
+ self.Category = Category
+
+ def test_does_not_create_models(self):
+ assert 'class' not in self.Article.__versioned__
+
+ def test_insert(self):
+ if self.options['native_versioning'] is False:
+ skip()
+ article = self.Article(name=u'Some article')
+ self.session.add(article)
+ self.session.commit()
+
+ version = dict(
+ self.session.execute('SELECT * FROM article_version')
+ .fetchone()
+ )
+ assert version['transaction_id'] > 0
+ assert version['id'] == article.id
+ assert version['name'] == u'Some article'
+
+
+class TestWithoutAnyVersionedModels(TestCase):
+ def create_models(self):
+ class Article(self.Model):
+ __tablename__ = 'article'
+
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255), nullable=False)
+ content = sa.Column(sa.UnicodeText)
+ description = sa.Column(sa.UnicodeText)
+
+ self.Article = Article
+
+ def test_insert(self):
+ article = self.Article(name=u'Some article')
+ self.session.add(article)
+ self.session.commit()
diff --git a/tests/test_delete.py b/tests/test_delete.py
index 35c8f42e..3934d2d7 100644
--- a/tests/test_delete.py
+++ b/tests/test_delete.py
@@ -1,3 +1,4 @@
+import sqlalchemy as sa
from tests import TestCase
@@ -23,3 +24,21 @@ def test_creates_versions_on_delete(self):
assert len(versions) == 2
assert versions[1].name == u'Some article'
assert versions[1].content == u'Some content'
+
+
+class TestDeleteWithDeferredColumn(TestCase):
+ def create_models(self):
+ class TextItem(self.Model):
+ __tablename__ = 'text_item'
+ __versioned__ = {}
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.orm.deferred(sa.Column(sa.Unicode(255)))
+
+ self.TextItem = TextItem
+
+ def test_insert_and_delete(self):
+ item = self.TextItem()
+ self.session.add(item)
+ self.session.commit()
+ self.session.delete(item)
+ self.session.commit()
diff --git a/tests/test_exotic_listener_chaining.py b/tests/test_exotic_listener_chaining.py
new file mode 100644
index 00000000..fc547311
--- /dev/null
+++ b/tests/test_exotic_listener_chaining.py
@@ -0,0 +1,31 @@
+import sqlalchemy as sa
+from sqlalchemy_continuum import versioning_manager
+from tests import TestCase
+
+
+class TestBeforeFlushListener(TestCase):
+ def setup_method(self, method):
+ @sa.event.listens_for(sa.orm.Session, 'before_flush')
+ def before_flush(session, ctx, instances):
+ for obj in session.dirty:
+ obj.name = u'Updated article'
+
+ self.before_flush = before_flush
+
+ TestCase.setup_method(self, method)
+ self.article = self.Article()
+ self.article.name = u'Some article'
+ self.article.content = u'Some content'
+ self.session.add(self.article)
+ self.session.commit()
+
+ def teardown_method(self, method):
+ TestCase.teardown_method(self, method)
+ sa.event.remove(sa.orm.Session, 'before_flush', self.before_flush)
+
+ def test_manual_tx_creation_with_no_actual_changes(self):
+ self.article.name = u'Some article'
+
+ uow = versioning_manager.unit_of_work(self.session)
+ uow.create_transaction(self.session)
+ self.session.flush()
diff --git a/tests/test_exotic_operation_combos.py b/tests/test_exotic_operation_combos.py
index e03a6f72..5a7dc013 100644
--- a/tests/test_exotic_operation_combos.py
+++ b/tests/test_exotic_operation_combos.py
@@ -11,12 +11,12 @@ def test_insert_deleted_object(self):
self.session.commit()
self.session.delete(article)
- article2 = self.Article(id=article.id, name=u'Some article')
+ article2 = self.Article(id=article.id, name=u'Some article 2')
self.session.add(article2)
self.session.commit()
assert article2.versions.count() == 2
assert article2.versions[0].operation_type == 0
- assert article2.versions[1].operation_type == 0
+ assert article2.versions[1].operation_type == 1
def test_insert_deleted_and_flushed_object(self):
article = self.Article()
diff --git a/tests/test_i18n.py b/tests/test_i18n.py
index 77418bc4..4c6ca9e0 100644
--- a/tests/test_i18n.py
+++ b/tests/test_i18n.py
@@ -1,9 +1,11 @@
import sqlalchemy as sa
from sqlalchemy_continuum import versioning_manager
from sqlalchemy_i18n import Translatable, make_translatable, translation_base
+from sqlalchemy_utils import i18n
from . import TestCase
+i18n.get_locale = lambda: 'en'
make_translatable()
@@ -25,7 +27,10 @@ class Article(self.Model, Translatable):
}
locale = 'en'
- id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ kwargs = dict(primary_key=True)
+ if self.driver != 'sqlite':
+ kwargs['autoincrement'] = True
+ id = sa.Column(sa.Integer, **kwargs)
description = sa.Column(sa.UnicodeText)
class ArticleTranslation(translation_base(Article)):
@@ -67,10 +72,8 @@ def test_history_with_many_translations(self):
self.article.description = u'Some text'
self.session.add(self.article)
- with self.article.force_locale('fi'):
- self.article.name = u'Text 1'
- with self.article.force_locale('en'):
- self.article.name = u'Text 2'
+ self.article.translations.fi.name = u'Text 1'
+ self.article.translations.en.name = u'Text 2'
self.session.commit()
diff --git a/tests/test_insert.py b/tests/test_insert.py
index 8ffe3ad6..12b0bf04 100644
--- a/tests/test_insert.py
+++ b/tests/test_insert.py
@@ -1,3 +1,6 @@
+import sqlalchemy as sa
+from sqlalchemy_continuum import count_versions, versioning_manager
+
from tests import TestCase
@@ -35,3 +38,45 @@ def test_multiple_consecutive_flushes(self):
self.session.commit()
assert article.versions.count() == 1
assert article2.versions.count() == 1
+
+
+class TestInsertWithDeferredColumn(TestCase):
+ def create_models(self):
+ class TextItem(self.Model):
+ __tablename__ = 'text_item'
+ __versioned__ = {}
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.orm.deferred(sa.Column(sa.Unicode(255)))
+
+ self.TextItem = TextItem
+
+ def test_insert(self):
+ item = self.TextItem()
+ self.session.add(item)
+ self.session.commit()
+ assert count_versions(item) == 1
+
+
+class TestInsertNonVersionedObject(TestCase):
+ def create_models(self):
+ class TextItem(self.Model):
+ __tablename__ = 'text_item'
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.orm.deferred(sa.Column(sa.Unicode(255)))
+
+ class Tag(self.Model):
+ __tablename__ = 'tag'
+ __versioned__ = {}
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.orm.deferred(sa.Column(sa.Unicode(255)))
+
+ self.TextItem = TextItem
+
+ def test_does_not_create_transaction(self):
+ item = self.TextItem()
+ self.session.add(item)
+ self.session.commit()
+
+ assert self.session.query(
+ versioning_manager.transaction_cls
+ ).count() == 0
diff --git a/tests/test_mapper_args.py b/tests/test_mapper_args.py
new file mode 100644
index 00000000..356f85f1
--- /dev/null
+++ b/tests/test_mapper_args.py
@@ -0,0 +1,80 @@
+import sqlalchemy as sa
+from sqlalchemy_continuum import version_class
+from tests import TestCase
+
+
+class TestColumnPrefix(TestCase):
+ def create_models(self):
+ class TextItem(self.Model):
+ __tablename__ = 'text_item'
+ __versioned__ = {
+ 'base_classes': (self.Model, )
+ }
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+
+ name = sa.Column(sa.Unicode(255))
+
+ __mapper_args__ = {
+ 'column_prefix': '_'
+ }
+
+ self.TextItem = TextItem
+
+ def setup_method(self, method):
+ TestCase.setup_method(self, method)
+ self.TextItemVersion = version_class(self.TextItem)
+
+ def test_supports_column_prefix(self):
+ assert self.TextItemVersion._id
+ assert self.TextItem._id
+
+
+class TestOrderByWithStringArg(TestCase):
+ def create_models(self):
+ class TextItem(self.Model):
+ __tablename__ = 'text_item'
+ __versioned__ = {
+ 'base_classes': (self.Model, )
+ }
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+
+ name = sa.Column(sa.Unicode(255))
+
+ __mapper_args__ = {
+ 'order_by': 'id',
+ 'column_prefix': '_'
+ }
+
+ self.TextItem = TextItem
+
+ def setup_method(self, method):
+ TestCase.setup_method(self, method)
+ self.TextItemVersion = version_class(self.TextItem)
+
+ def test_reflects_order_by(self):
+ assert self.TextItemVersion.__mapper_args__['order_by'] == 'id'
+
+
+class TestOrderByWithInstrumentedAttribute(TestCase):
+ def create_models(self):
+ class TextItem(self.Model):
+ __tablename__ = 'text_item'
+ __versioned__ = {
+ 'base_classes': (self.Model, )
+ }
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+
+ name = sa.Column(sa.Unicode(255))
+
+ __mapper_args__ = {
+ 'order_by': id
+ }
+
+ self.TextItem = TextItem
+
+ def setup_method(self, method):
+ TestCase.setup_method(self, method)
+ self.TextItemVersion = version_class(self.TextItem)
+
+ def test_reflects_order_by(self):
+ assert 'order_by' not in self.TextItemVersion.__mapper_args__
diff --git a/tests/test_raw_sql.py b/tests/test_raw_sql.py
new file mode 100644
index 00000000..92b0896e
--- /dev/null
+++ b/tests/test_raw_sql.py
@@ -0,0 +1,30 @@
+import pytest
+from sqlalchemy_continuum import versioning_manager
+
+from tests import TestCase, uses_native_versioning
+
+
+@pytest.mark.skipif('not uses_native_versioning()')
+class TestRawSQL(TestCase):
+ def assert_has_single_transaction(self):
+ assert (
+ self.session.query(versioning_manager.transaction_cls)
+ .count() == 1
+ )
+
+ def test_flush_after_raw_insert(self):
+ self.session.execute(
+ "INSERT INTO article (name) VALUES ('some article')"
+ )
+ self.session.add(self.Article(name=u'some other article'))
+ self.session.commit()
+ self.assert_has_single_transaction()
+
+ def test_raw_insert_after_flush(self):
+ self.session.add(self.Article(name=u'some other article'))
+ self.session.flush()
+ self.session.execute(
+ "INSERT INTO article (name) VALUES ('some article')"
+ )
+ self.session.commit()
+ self.assert_has_single_transaction()
diff --git a/tests/test_revert.py b/tests/test_revert.py
index 5a8b0dc9..e37f64a3 100644
--- a/tests/test_revert.py
+++ b/tests/test_revert.py
@@ -89,6 +89,51 @@ def test_revert_version_with_one_to_many_relation(self):
assert len(article.tags) == 1
assert article.tags[0].name == u'some tag'
+ def test_with_one_to_many_relation_delete_newly_added(self):
+ article = self.Article()
+ article.name = u'Some article'
+ article.content = u'Some content'
+ article.tags.append(self.Tag(name=u'some tag'))
+ self.session.add(article)
+ self.session.commit()
+ article.name = u'Updated name'
+ article.content = u'Updated content'
+ article.tags.append(self.Tag(name=u'some other tag'))
+ self.session.add(article)
+ self.session.commit()
+ self.session.refresh(article)
+ assert len(article.tags) == 2
+ assert len(article.versions[0].tags) == 1
+ assert article.versions[0].tags[0].article
+ article.versions[0].revert(relations=['tags'])
+ self.session.commit()
+
+ assert article.name == u'Some article'
+ assert article.content == u'Some content'
+ assert len(article.tags) == 1
+ assert article.tags[0].name == u'some tag'
+
+ def test_with_one_to_many_relation_resurrect_deleted(self):
+ article = self.Article()
+ article.name = u'Some article'
+ article.content = u'Some content'
+ tag = self.Tag(name=u'some other tag')
+ article.tags.append(self.Tag(name=u'some tag'))
+ article.tags.append(tag)
+ self.session.add(article)
+ self.session.commit()
+ article.name = u'Updated name'
+ article.tags.remove(tag)
+ self.session.add(article)
+ self.session.commit()
+ self.session.refresh(article)
+ assert len(article.tags) == 1
+ assert len(article.versions[0].tags) == 2
+ article.versions[0].revert(relations=['tags'])
+ self.session.commit()
+ assert len(article.tags) == 2
+ assert article.tags[0].name == u'some tag'
+
class TestRevertWithDefaultVersioningStrategy(RevertTestCase):
pass
diff --git a/tests/test_savepoints.py b/tests/test_savepoints.py
new file mode 100644
index 00000000..8e52f3d7
--- /dev/null
+++ b/tests/test_savepoints.py
@@ -0,0 +1,45 @@
+import pytest
+
+from tests import TestCase
+
+
+class TestSavepoints(TestCase):
+ def test_flush_and_nested_rollback(self):
+ article = self.Article(name=u'Some article')
+ self.session.add(article)
+ self.session.flush()
+ self.session.begin_nested()
+ self.session.add(self.Article(name=u'Some article'))
+ article.name = u'Updated name'
+ self.session.rollback()
+ self.session.commit()
+ assert article.versions.count() == 1
+ assert article.versions[-1].name == u'Some article'
+
+ def test_partial_rollback(self):
+ article = self.Article(name=u'Some article')
+ self.session.add(article)
+ self.session.begin_nested()
+ self.session.add(self.Article(name=u'Some article'))
+ article.name = u'Updated name'
+ self.session.rollback()
+ self.session.commit()
+ assert article.versions.count() == 1
+ assert article.versions[-1].name == u'Some article'
+
+ def test_multiple_savepoints(self):
+ if self.driver == 'sqlite':
+ pytest.skip()
+
+ article = self.Article(name=u'Some article')
+ self.session.add(article)
+ self.session.flush()
+ self.session.begin_nested()
+ article.name = u'Updated name'
+ self.session.commit()
+ self.session.begin_nested()
+ article.name = u'Another article'
+ self.session.commit()
+ self.session.commit()
+ assert article.versions.count() == 1
+ assert article.versions[-1].name == u'Another article'
diff --git a/tests/test_sessions.py b/tests/test_sessions.py
index 8d799d7a..6a1fbfb0 100644
--- a/tests/test_sessions.py
+++ b/tests/test_sessions.py
@@ -18,21 +18,34 @@ def test_multiple_connections(self):
self.session.commit()
self.session2.commit()
- assert article.versions[-1].transaction_id == 1
- assert article2.versions[-1].transaction_id == 2
+ assert article.versions[-1].transaction_id
+ assert (
+ article2.versions[-1].transaction_id >
+ article.versions[-1].transaction_id
+ )
+
+ def test_connection_binded_to_engine(self):
+ self.session2 = Session(bind=self.engine)
+ article = self.Article(name=u'Session1 article')
+ self.session2.add(article)
+ self.session2.commit()
+ assert article.versions[-1].transaction_id
def test_manual_transaction_creation(self):
uow = versioning_manager.unit_of_work(self.session)
transaction = uow.create_transaction(self.session)
self.session.flush()
- assert transaction.id == 1
+ assert transaction.id
article = self.Article(name=u'Session1 article')
self.session.add(article)
self.session.flush()
- assert uow.current_transaction.id == 1
+ assert uow.current_transaction.id
+
+ self.session.commit()
+ assert article.versions[-1].transaction_id
+ def test_commit_without_objects(self):
self.session.commit()
- assert article.versions[-1].transaction_id == 1
class TestUnitOfWork(TestCase):
@@ -40,16 +53,18 @@ def test_with_session_arg(self):
uow = versioning_manager.unit_of_work(self.session)
assert isinstance(uow, UnitOfWork)
- def test_with_connection_arg(self):
- uow = versioning_manager.unit_of_work(self.session.bind)
- assert isinstance(uow, UnitOfWork)
- def test_with_entity_arg(self):
- article = self.Article()
- self.session.add(article)
- uow = versioning_manager.unit_of_work(article)
- assert isinstance(uow, UnitOfWork)
+class TestExternalTransactionSession(TestCase):
+
+ def test_session_with_external_transaction(self):
+ conn = self.engine.connect()
+ t = conn.begin()
+ session = Session(bind=conn)
+
+ article = self.Article(name=u'My Session Article')
+ session.add(article)
+ session.flush()
- def test_raises_type_error_for_unknown_type(self):
- with raises(TypeError):
- versioning_manager.unit_of_work(None)
+ session.close()
+ t.rollback()
+ conn.close()
diff --git a/tests/test_transaction.py b/tests/test_transaction.py
index 243fb553..9d9e8f1b 100644
--- a/tests/test_transaction.py
+++ b/tests/test_transaction.py
@@ -1,5 +1,7 @@
+import sqlalchemy as sa
from sqlalchemy_continuum import versioning_manager
from tests import TestCase
+from pytest import mark
class TestTransaction(TestCase):
@@ -13,9 +15,7 @@ def setup_method(self, method):
self.session.commit()
def test_relationships(self):
- tx = self.article.versions[0].transaction
- assert tx.id == self.article.versions[0].transaction_id
- assert tx.articles == [self.article.versions[0]]
+ assert self.article.versions[0].transaction
def test_only_saves_transaction_if_actual_modifications(self):
self.article.name = u'Some article'
@@ -25,3 +25,63 @@ def test_only_saves_transaction_if_actual_modifications(self):
assert self.session.query(
versioning_manager.transaction_cls
).count() == 1
+
+ def test_repr(self):
+ transaction = self.session.query(
+ versioning_manager.transaction_cls
+ ).first()
+ assert (
+ '' % (
+ transaction.id,
+ transaction.issued_at
+ ) ==
+ repr(transaction)
+ )
+
+
+class TestAssigningUserClass(TestCase):
+ user_cls = 'User'
+
+ def create_models(self):
+ class User(self.Model):
+ __tablename__ = 'user'
+ __versioned__ = {
+ 'base_classes': (self.Model, )
+ }
+
+ id = sa.Column(sa.Unicode(255), primary_key=True)
+ name = sa.Column(sa.Unicode(255), nullable=False)
+
+ self.User = User
+
+ def test_copies_primary_key_type_from_user_class(self):
+ attr = versioning_manager.transaction_cls.user_id
+ assert isinstance(attr.property.columns[0].type, sa.Unicode)
+
+
+@mark.skipif("os.environ.get('DB') == 'sqlite'")
+class TestAssigningUserClassInOtherSchema(TestCase):
+ user_cls = 'User'
+
+ def create_models(self):
+ class User(self.Model):
+ __tablename__ = 'user'
+ __versioned__ = {
+ 'base_classes': (self.Model,)
+ }
+ __table_args__ = {'schema': 'other'}
+
+ id = sa.Column(sa.Unicode(255), primary_key=True)
+ name = sa.Column(sa.Unicode(255), nullable=False)
+
+ self.User = User
+
+ def create_tables(self):
+ self.connection.execute('DROP SCHEMA IF EXISTS other')
+ self.connection.execute('CREATE SCHEMA other')
+ TestCase.create_tables(self)
+
+ def test_can_build_transaction_model(self):
+ # If create_models didn't crash this should be good
+ pass
+
diff --git a/tests/test_utils.py b/tests/test_utils.py
deleted file mode 100644
index 4b4cb829..00000000
--- a/tests/test_utils.py
+++ /dev/null
@@ -1,141 +0,0 @@
-from pytest import raises
-
-from datetime import datetime
-import sqlalchemy as sa
-from sqlalchemy_continuum import (
- changeset, transaction_class, versioning_manager
-)
-from sqlalchemy_continuum.utils import (
- get_bind,
- is_modified,
- parent_class,
- tx_column_name,
- version_class,
-)
-
-from tests import TestCase, create_test_cases
-
-
-class TestChangeSet(TestCase):
- def test_changeset_for_new_value(self):
- article = self.Article(name=u'Some article')
- assert changeset(article) == {'name': [u'Some article', None]}
-
- def test_changeset_for_deletion(self):
- article = self.Article(name=u'Some article')
- self.session.add(article)
- self.session.commit()
- self.session.delete(article)
- assert changeset(article) == {'name': [None, u'Some article']}
-
- def test_changeset_for_update(self):
- article = self.Article(name=u'Some article')
- self.session.add(article)
- self.session.commit()
- article.tags
- article.name = u'Updated article'
- assert changeset(article) == {
- 'name': [u'Updated article', u'Some article']
- }
-
-
-class TestIsModified(TestCase):
- def create_models(self):
- class Article(self.Model):
- __tablename__ = 'article'
- __versioned__ = {
- 'exclude': 'content'
- }
- id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
- name = sa.Column(sa.Unicode(255))
- created_at = sa.Column(sa.DateTime, default=datetime.now)
- content = sa.Column(sa.Unicode(255))
-
- self.Article = Article
-
- def test_included_column(self):
- article = self.Article(name=u'Some article')
- assert is_modified(article)
-
- def test_excluded_column(self):
- article = self.Article(content=u'Some content')
- assert not is_modified(article)
-
- def test_auto_assigned_datetime_exclusion(self):
- article = self.Article(created_at=datetime.now())
- assert not is_modified(article)
-
-
-class TestVersionClass(TestCase):
- def test_version_class_for_versioned_class(self):
- ArticleVersion = version_class(self.Article)
- assert ArticleVersion.__name__ == 'ArticleVersion'
-
- def test_throws_error_for_non_versioned_class(self):
- with raises(KeyError):
- parent_class(self.Article)
-
-
-class TestGetBind(TestCase):
- def test_with_session(self):
- assert get_bind(self.session) == self.connection
-
- def test_with_connection(self):
- assert get_bind(self.connection) == self.connection
-
- def test_with_model_object(self):
- article = self.Article()
- self.session.add(article)
- assert get_bind(article) == self.connection
-
- def test_with_unknown_type(self):
- with raises(TypeError):
- get_bind(None)
-
-
-class TestTransactionClass(TestCase):
- def test_with_versioned_class(self):
- assert (
- transaction_class(self.Article) ==
- versioning_manager.transaction_cls
- )
-
- def test_with_unknown_type(self):
- with raises(AttributeError):
- transaction_class(None)
-
-
-class TestParentClass(TestCase):
- def test_parent_class_for_version_class(self):
- ArticleVersion = version_class(self.Article)
- assert parent_class(ArticleVersion) == self.Article
-
- def test_throws_error_for_non_version_class(self):
- with raises(KeyError):
- parent_class(self.Article)
-
-
-setting_variants = {
- 'transaction_column_name': ['transaction_id', 'tx_id'],
-}
-
-
-class TxColumnNameTestCase(TestCase):
- def test_with_version_class(self):
- assert tx_column_name(version_class(self.Article)) == self.options[
- 'transaction_column_name'
- ]
-
- def test_with_version_obj(self):
- history_obj = version_class(self.Article)()
- assert tx_column_name(history_obj) == self.options[
- 'transaction_column_name'
- ]
-
- def test_with_versioned_class(self):
- assert tx_column_name(self.Article) == self.options[
- 'transaction_column_name'
- ]
-
-
-create_test_cases(TxColumnNameTestCase, setting_variants=setting_variants)
diff --git a/tests/test_validity_strategy.py b/tests/test_validity_strategy.py
index ae273721..2d9ea567 100644
--- a/tests/test_validity_strategy.py
+++ b/tests/test_validity_strategy.py
@@ -60,7 +60,7 @@ class TextItem(self.Model):
__tablename__ = 'text_item'
__versioned__ = {
'base_classes': (self.Model, ),
- 'strategy': 'validity'
+ 'strategy': 'validity',
}
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
diff --git a/tests/test_versions.py b/tests/test_versions.py
new file mode 100644
index 00000000..0fcfdc6e
--- /dev/null
+++ b/tests/test_versions.py
@@ -0,0 +1,24 @@
+from tests import TestCase
+
+
+class TestVersions(TestCase):
+ def test_versions_ordered_by_transaction_id(self):
+ names = [
+ u'Some article',
+ u'Update 1 article',
+ u'Update 2 article',
+ u'Update 3 article',
+ ]
+
+ article = self.Article(name=names[0])
+ self.session.add(article)
+ self.session.commit()
+ article.name = names[1]
+ self.session.commit()
+ article.name = names[2]
+ self.session.commit()
+ article.name = names[3]
+ self.session.commit()
+
+ for index, name in enumerate(names):
+ assert article.versions[index].name == name
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/utils/test_changeset.py b/tests/utils/test_changeset.py
new file mode 100644
index 00000000..1f3eb4fa
--- /dev/null
+++ b/tests/utils/test_changeset.py
@@ -0,0 +1,25 @@
+from sqlalchemy_continuum import changeset
+from tests import TestCase
+
+
+class TestChangeSet(TestCase):
+ def test_changeset_for_new_value(self):
+ article = self.Article(name=u'Some article')
+ assert changeset(article) == {'name': [u'Some article', None]}
+
+ def test_changeset_for_deletion(self):
+ article = self.Article(name=u'Some article')
+ self.session.add(article)
+ self.session.commit()
+ self.session.delete(article)
+ assert changeset(article) == {'name': [None, u'Some article']}
+
+ def test_changeset_for_update(self):
+ article = self.Article(name=u'Some article')
+ self.session.add(article)
+ self.session.commit()
+ article.tags
+ article.name = u'Updated article'
+ assert changeset(article) == {
+ 'name': [u'Updated article', u'Some article']
+ }
diff --git a/tests/utils/test_count_versions.py b/tests/utils/test_count_versions.py
new file mode 100644
index 00000000..edd4acf5
--- /dev/null
+++ b/tests/utils/test_count_versions.py
@@ -0,0 +1,30 @@
+from sqlalchemy_continuum import count_versions
+from tests import TestCase
+
+
+class TestCountVersions(TestCase):
+ def test_count_versions_without_versions(self):
+ article = self.Article(name=u'Some article')
+ assert count_versions(article) == 0
+
+ def test_count_versions_with_initial_version(self):
+ article = self.Article(name=u'Some article')
+ self.session.add(article)
+ self.session.commit()
+ assert count_versions(article) == 1
+
+ def test_count_versions_with_multiple_versions(self):
+ article = self.Article(name=u'Some article')
+ self.session.add(article)
+ self.session.commit()
+ article.name = u'Updated article'
+ self.session.commit()
+ assert count_versions(article) == 2
+
+ def test_count_versions_with_multiple_objects(self):
+ article = self.Article(name=u'Some article')
+ self.session.add(article)
+ article2 = self.Article(name=u'Some article')
+ self.session.add(article2)
+ self.session.commit()
+ assert count_versions(article) == 1
diff --git a/tests/utils/test_is_modified.py b/tests/utils/test_is_modified.py
new file mode 100644
index 00000000..1066e48e
--- /dev/null
+++ b/tests/utils/test_is_modified.py
@@ -0,0 +1,28 @@
+from datetime import datetime
+import sqlalchemy as sa
+from sqlalchemy_continuum import is_modified
+
+from tests import TestCase
+
+
+class TestIsModified(TestCase):
+ def create_models(self):
+ class Article(self.Model):
+ __tablename__ = 'article'
+ __versioned__ = {
+ 'exclude': 'content'
+ }
+ id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
+ name = sa.Column(sa.Unicode(255))
+ created_at = sa.Column(sa.DateTime, default=datetime.now)
+ content = sa.Column(sa.Unicode(255))
+
+ self.Article = Article
+
+ def test_included_column(self):
+ article = self.Article(name=u'Some article')
+ assert is_modified(article)
+
+ def test_excluded_column(self):
+ article = self.Article(content=u'Some content')
+ assert not is_modified(article)
diff --git a/tests/utils/test_parent_class.py b/tests/utils/test_parent_class.py
new file mode 100644
index 00000000..1bb1fcfa
--- /dev/null
+++ b/tests/utils/test_parent_class.py
@@ -0,0 +1,14 @@
+from pytest import raises
+from sqlalchemy_continuum import parent_class, version_class
+
+from tests import TestCase
+
+
+class TestParentClass(TestCase):
+ def test_parent_class_for_version_class(self):
+ ArticleVersion = version_class(self.Article)
+ assert parent_class(ArticleVersion) == self.Article
+
+ def test_throws_error_for_non_version_class(self):
+ with raises(KeyError):
+ parent_class(self.Article)
diff --git a/tests/utils/test_transaction_class.py b/tests/utils/test_transaction_class.py
new file mode 100644
index 00000000..2af1b17a
--- /dev/null
+++ b/tests/utils/test_transaction_class.py
@@ -0,0 +1,20 @@
+from pytest import raises
+
+from sqlalchemy_continuum import (
+ ClassNotVersioned,
+ transaction_class,
+ versioning_manager
+)
+from tests import TestCase
+
+
+class TestTransactionClass(TestCase):
+ def test_with_versioned_class(self):
+ assert (
+ transaction_class(self.Article) ==
+ versioning_manager.transaction_cls
+ )
+
+ def test_with_unknown_type(self):
+ with raises(ClassNotVersioned):
+ transaction_class(None)
diff --git a/tests/utils/test_tx_column_name.py b/tests/utils/test_tx_column_name.py
new file mode 100644
index 00000000..142742d7
--- /dev/null
+++ b/tests/utils/test_tx_column_name.py
@@ -0,0 +1,29 @@
+from sqlalchemy_continuum import tx_column_name, version_class
+
+from tests import TestCase, create_test_cases
+
+
+setting_variants = {
+ 'transaction_column_name': ['transaction_id', 'tx_id'],
+}
+
+
+class TxColumnNameTestCase(TestCase):
+ def test_with_version_class(self):
+ assert tx_column_name(version_class(self.Article)) == self.options[
+ 'transaction_column_name'
+ ]
+
+ def test_with_version_obj(self):
+ history_obj = version_class(self.Article)()
+ assert tx_column_name(history_obj) == self.options[
+ 'transaction_column_name'
+ ]
+
+ def test_with_versioned_class(self):
+ assert tx_column_name(self.Article) == self.options[
+ 'transaction_column_name'
+ ]
+
+
+create_test_cases(TxColumnNameTestCase, setting_variants=setting_variants)
diff --git a/tests/utils/test_version_class.py b/tests/utils/test_version_class.py
new file mode 100644
index 00000000..08b69a23
--- /dev/null
+++ b/tests/utils/test_version_class.py
@@ -0,0 +1,23 @@
+from pytest import raises
+from sqlalchemy_continuum import ClassNotVersioned, version_class
+from sqlalchemy_continuum.manager import VersioningManager
+from sqlalchemy_continuum.model_builder import ModelBuilder
+
+from tests import TestCase
+
+
+class TestVersionClass(TestCase):
+ def test_version_class_for_versioned_class(self):
+ ArticleVersion = version_class(self.Article)
+ assert ArticleVersion.__name__ == 'ArticleVersion'
+
+ def test_throws_error_for_non_versioned_class(self):
+ with raises(ClassNotVersioned):
+ version_class('invalid')
+
+ def test_module_name_in_class_name(self):
+ options = {'use_module_name': True}
+ vm = VersioningManager(options=options)
+ mb = ModelBuilder(vm, self.Article)
+ ArticleVersion = mb.build_model(self.Article.__table__)
+ assert ArticleVersion.__name__ == 'TestsArticleVersion'
diff --git a/tox.ini b/tox.ini
index 296cdf3f..0f2033d4 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,5 +1,5 @@
[tox]
-envlist = py26, py27, py33
+envlist = py27, py33, py34, py35
[testenv]
commands = pip install -e ".[test]"