Skip to content

Commit

Permalink
Fix kvesteri#238: Avoid hardcoded options.table_name in version_table()
Browse files Browse the repository at this point in the history
  • Loading branch information
indiVar0508 committed Sep 15, 2022
1 parent 27897ee commit ffb3c64
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 15 deletions.
17 changes: 14 additions & 3 deletions sqlalchemy_continuum/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import sqlalchemy as sa
from sqlalchemy.orm import object_session
from sqlalchemy_utils import get_column_key
from sqlalchemy_utils import get_column_key, get_mapper

from .builder import Builder
from .fetcher import SubqueryFetcher, ValidityFetcher
Expand Down Expand Up @@ -186,16 +186,27 @@ def is_excluded_property(self, model, key):
return False
return key in self.option(model, 'exclude')

def option(self, model, name):
def option(self, model_or_table, name):
"""
Returns the option value for given model. If the option is not found
from given model falls back to default values of this manager object.
If the option is not found from this manager object either this method
throws a KeyError.
:param model: SQLAlchemy declarative object
:param model_or_table: SQLAlchemy declarative object
:param name: name of the versioning option
"""
if isinstance(model_or_table, sa.Table):
table = model_or_table
if table in self.association_tables:
return self.options[name]
if hasattr(table, 'model'):
model = table.model
else:
raise TypeError('Table %r is not versioned.' % table)
else:
model = model_or_table

if not hasattr(model, '__versioned__'):
raise TypeError('Model %r is not versioned.' % model)
try:
Expand Down
2 changes: 2 additions & 0 deletions sqlalchemy_continuum/table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def __call__(self, extends=None):
Builds version table.
"""
self.parent_table.__versioning_manager__ = self.manager
self.parent_table.model = self.model
columns = self.columns if extends is None else []
self.manager.plugins.after_build_version_table_columns(self, columns)
version_table = sa.schema.Table(
Expand All @@ -155,4 +156,5 @@ def __call__(self, extends=None):
extend_existing=extends is not None
)
version_table.__versioning_manager__ = self.manager
version_table.model = self.model
return version_table
34 changes: 22 additions & 12 deletions sqlalchemy_continuum/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,34 @@ def get_versioning_manager(obj_or_class_or_table):
try:
return cls_or_table.__versioning_manager__
except AttributeError:
if issubclass(cls_or_table, sa.Table):
if isinstance(cls_or_table, sa.Table):
name = 'Table "%s"' % cls_or_table.name
else:
name = cls_or_table.__name__
raise ClassNotVersioned(name)


def option(obj_or_class, option_name):
def option(obj_or_class_or_table, option_name):
"""
Return the option value of given option for given versioned object or
class.
:param obj_or_class: SQLAlchemy declarative model object or class
:param option_name: The name of an option to return
"""
if isinstance(obj_or_class, AliasedClass):
obj_or_class = sa.inspect(obj_or_class).mapper.class_
cls = obj_or_class if isclass(obj_or_class) else obj_or_class.__class__
if not hasattr(cls, '__versioned__'):
cls = parent_class(cls)
return get_versioning_manager(cls).option(
cls, option_name
if isclass(obj_or_class_or_table):
cls_or_table = obj_or_class_or_table
else:
if isinstance(obj_or_class_or_table, AliasedClass):
cls_or_table = sa.inspect(obj_or_class_or_table).mapper.class_
elif isinstance(obj_or_class_or_table, sa.Table):
cls_or_table = obj_or_class_or_table
else:
cls_or_table = obj_or_class_or_table.__class__
if isclass(cls_or_table) and not hasattr(cls_or_table, '__versioned__'):
cls_or_table = parent_class(cls_or_table)
return get_versioning_manager(cls_or_table).option(
cls_or_table, option_name
)


Expand Down Expand Up @@ -144,17 +150,21 @@ def version_table(table):
:param table: SQLAlchemy Table object
"""
try:
suffixed_table = option(table, 'table_name') % table.name
except ClassNotVersioned:
suffixed_table = table.name + '_version' # to have same behaviour and generate key error for expression_reflector
if table.schema:
return table.metadata.tables[
table.schema + '.' + table.name + '_version'
table.schema + '.' + suffixed_table
]
elif table.metadata.schema:
return table.metadata.tables[
table.metadata.schema + '.' + table.name + '_version'
table.metadata.schema + '.' + suffixed_table
]
else:
return table.metadata.tables[
table.name + '_version'
suffixed_table
]


Expand Down
90 changes: 90 additions & 0 deletions tests/utils/test_version_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pytest
import datetime
import sqlalchemy as sa
from tests import TestCase, uses_native_versioning

from sqlalchemy_continuum.utils import version_table

class TestVersionTableDefault(TestCase):

def create_models(self):
super().create_models()

article_author_table = sa.Table(
'article_author',
self.Model.metadata,
sa.Column('article_id', sa.Integer, sa.ForeignKey('article.id'), primary_key=True, nullable=False),
sa.Column('author_id', sa.Integer, sa.ForeignKey('author.id'), primary_key=True, nullable=False),
sa.Column('created_date', sa.DateTime, nullable=False, server_default=sa.func.current_timestamp(), default=datetime.datetime.utcnow),
)

user_activity_table = sa.Table(
'user_activity',
self.Model.metadata,
sa.Column('user_id', sa.INTEGER, sa.ForeignKey('user.id'), nullable=False),
sa.Column('login_time', sa.DateTime, nullable=False),
sa.Column('logout_time', sa.DateTime, nullable=False)
)
class Author(self.Model):
__tablename__ = 'author'
__versioned__ = {
'table_name': '%s_custom'
}
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
articles = sa.orm.relationship('Article', secondary=article_author_table, backref='author')

class User(self.Model):
__tablename__ = 'user'

id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))

self.User = User
self.Author = Author
self.article_author_table = article_author_table
self.user_activity_table = user_activity_table

def test_version_table_with_model(self):
ArticleVersionTableName = version_table(self.Article.__table__)
assert ArticleVersionTableName.fullname == 'article_version'

def test_version_table_with_association_table(self):
ArticleAuthorVersionedTableName = version_table(self.article_author_table)
assert ArticleAuthorVersionedTableName.fullname == 'article_author_version'

def test_version_table_with_model_version_attr(self):
AuthorVersionedTableName = version_table(self.Author.__table__)
assert AuthorVersionedTableName.fullname == 'author_custom'

def test_version_table_with_non_version_model(self):
with pytest.raises(KeyError):
version_table(self.User.__table__)

def test_version_table_with_non_version_table(self):
with pytest.raises(KeyError):
version_table(self.user_activity_table)

class TestVersionTableUserDefined(TestVersionTableDefault):


@property
def options(self):
return {
'create_models': self.should_create_models,
'native_versioning': uses_native_versioning(),
'base_classes': (self.Model, ),
'strategy': self.versioning_strategy,
'transaction_column_name': self.transaction_column_name,
'end_transaction_column_name': self.end_transaction_column_name,
'table_name': '%s_user_defined'
}

def test_version_table_with_model(self):
ArticleVersionTableName = version_table(self.Article.__table__)
assert ArticleVersionTableName.fullname == 'article_user_defined'

def test_version_table_with_association_table(self):
ArticleAuthorVersionedTableName = version_table(self.article_author_table)
assert ArticleAuthorVersionedTableName.fullname == 'article_author_user_defined'

0 comments on commit ffb3c64

Please sign in to comment.