diff --git a/sqlalchemy_continuum/manager.py b/sqlalchemy_continuum/manager.py index 0e945d53..f422042c 100644 --- a/sqlalchemy_continuum/manager.py +++ b/sqlalchemy_continuum/manager.py @@ -3,7 +3,7 @@ import sqlalchemy as sa from sqlalchemy.orm import object_session -from sqlalchemy_utils import get_column_key +from sqlalchemy_utils import get_column_key, get_mapper from .builder import Builder from .fetcher import SubqueryFetcher, ValidityFetcher @@ -186,16 +186,27 @@ def is_excluded_property(self, model, key): return False return key in self.option(model, 'exclude') - def option(self, model, name): + def option(self, model_or_table, name): """ Returns the option value for given model. If the option is not found from given model falls back to default values of this manager object. If the option is not found from this manager object either this method throws a KeyError. - :param model: SQLAlchemy declarative object + :param model_or_table: SQLAlchemy declarative object :param name: name of the versioning option """ + if isinstance(model_or_table, sa.Table): + table = model_or_table + if table in self.association_tables: + return self.options[name] + try: + model = get_mapper(table).class_ + except ValueError: + raise TypeError('Table %r is not versioned.' % table) + else: + model = model_or_table + if not hasattr(model, '__versioned__'): raise TypeError('Model %r is not versioned.' % model) try: diff --git a/sqlalchemy_continuum/utils.py b/sqlalchemy_continuum/utils.py index a87f58ab..4421b0e9 100644 --- a/sqlalchemy_continuum/utils.py +++ b/sqlalchemy_continuum/utils.py @@ -34,14 +34,14 @@ def get_versioning_manager(obj_or_class_or_table): try: return cls_or_table.__versioning_manager__ except AttributeError: - if issubclass(cls_or_table, sa.Table): + if isinstance(cls_or_table, sa.Table): name = 'Table "%s"' % cls_or_table.name else: name = cls_or_table.__name__ raise ClassNotVersioned(name) -def option(obj_or_class, option_name): +def option(obj_or_class_or_table, option_name): """ Return the option value of given option for given versioned object or class. @@ -49,13 +49,19 @@ def option(obj_or_class, option_name): :param obj_or_class: SQLAlchemy declarative model object or class :param option_name: The name of an option to return """ - if isinstance(obj_or_class, AliasedClass): - obj_or_class = sa.inspect(obj_or_class).mapper.class_ - cls = obj_or_class if isclass(obj_or_class) else obj_or_class.__class__ - if not hasattr(cls, '__versioned__'): - cls = parent_class(cls) - return get_versioning_manager(cls).option( - cls, option_name + if isclass(obj_or_class_or_table): + cls_or_table = obj_or_class_or_table + else: + if isinstance(obj_or_class_or_table, AliasedClass): + cls_or_table = sa.inspect(obj_or_class_or_table).mapper.class_ + elif isinstance(obj_or_class_or_table, sa.Table): + cls_or_table = obj_or_class_or_table + else: + cls_or_table = obj_or_class_or_table.__class__ + if isclass(cls_or_table) and not hasattr(cls_or_table, '__versioned__'): + cls_or_table = parent_class(cls_or_table) + return get_versioning_manager(cls_or_table).option( + cls_or_table, option_name ) @@ -144,17 +150,21 @@ def version_table(table): :param table: SQLAlchemy Table object """ + try: + suffixed_table = option(table, 'table_name') % table.name + except ClassNotVersioned: + suffixed_table = table.name + '_version' # to have same behaviour and generate key error for expression_reflector if table.schema: return table.metadata.tables[ - table.schema + '.' + table.name + '_version' + table.schema + '.' + suffixed_table ] elif table.metadata.schema: return table.metadata.tables[ - table.metadata.schema + '.' + table.name + '_version' + table.metadata.schema + '.' + suffixed_table ] else: return table.metadata.tables[ - table.name + '_version' + suffixed_table ] diff --git a/tests/utils/test_version_table.py b/tests/utils/test_version_table.py new file mode 100644 index 00000000..72a6a9a1 --- /dev/null +++ b/tests/utils/test_version_table.py @@ -0,0 +1,90 @@ +import pytest +import datetime +import sqlalchemy as sa +from tests import TestCase, uses_native_versioning + +from sqlalchemy_continuum.utils import version_table + +class TestVersionTableDefault(TestCase): + + def create_models(self): + super().create_models() + + article_author_table = sa.Table( + 'article_author', + self.Model.metadata, + sa.Column('article_id', sa.Integer, sa.ForeignKey('article.id'), primary_key=True, nullable=False), + sa.Column('author_id', sa.Integer, sa.ForeignKey('author.id'), primary_key=True, nullable=False), + sa.Column('created_date', sa.DateTime, nullable=False, server_default=sa.func.current_timestamp(), default=datetime.datetime.utcnow), + ) + + user_activity_table = sa.Table( + 'user_activity', + self.Model.metadata, + sa.Column('user_id', sa.INTEGER, sa.ForeignKey('user.id'), nullable=False), + sa.Column('login_time', sa.DateTime, nullable=False), + sa.Column('logout_time', sa.DateTime, nullable=False) + ) + class Author(self.Model): + __tablename__ = 'author' + __versioned__ = { + 'table_name': '%s_custom' + } + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + articles = sa.orm.relationship('Article', secondary=article_author_table, backref='author') + + class User(self.Model): + __tablename__ = 'user' + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + self.User = User + self.Author = Author + self.article_author_table = article_author_table + self.user_activity_table = user_activity_table + + def test_version_table_with_model(self): + ArticleVersionTableName = version_table(self.Article.__table__) + assert ArticleVersionTableName.fullname == 'article_version' + + def test_version_table_with_association_table(self): + ArticleAuthorVersionedTableName = version_table(self.article_author_table) + assert ArticleAuthorVersionedTableName.fullname == 'article_author_version' + + def test_version_table_with_model_version_attr(self): + AuthorVersionedTableName = version_table(self.Author.__table__) + assert AuthorVersionedTableName.fullname == 'author_custom' + + def test_version_table_with_non_version_model(self): + with pytest.raises(KeyError): + version_table(self.User.__table__) + + def test_version_table_with_non_version_table(self): + with pytest.raises(KeyError): + version_table(self.user_activity_table) + +class TestVersionTableUserDefined(TestVersionTableDefault): + + + @property + def options(self): + return { + 'create_models': self.should_create_models, + 'native_versioning': uses_native_versioning(), + 'base_classes': (self.Model, ), + 'strategy': self.versioning_strategy, + 'transaction_column_name': self.transaction_column_name, + 'end_transaction_column_name': self.end_transaction_column_name, + 'table_name': '%s_user_defined' + } + + def test_version_table_with_model(self): + ArticleVersionTableName = version_table(self.Article.__table__) + assert ArticleVersionTableName.fullname == 'article_user_defined' + + def test_version_table_with_association_table(self): + ArticleAuthorVersionedTableName = version_table(self.article_author_table) + assert ArticleAuthorVersionedTableName.fullname == 'article_author_user_defined' +