Skip to content

Commit

Permalink
Fix kvesteri#238: Avoid hardcoded options.table_name in version_table()
Browse files Browse the repository at this point in the history
  • Loading branch information
indiVar0508 committed Sep 11, 2022
1 parent eb974b8 commit 7910f83
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 17 deletions.
12 changes: 7 additions & 5 deletions sqlalchemy_continuum/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,20 +186,22 @@ def is_excluded_property(self, model, key):
return False
return key in self.option(model, 'exclude')

def option(self, model, name):
def option(self, model_or_table, name):
"""
Returns the option value for given model. If the option is not found
from given model falls back to default values of this manager object.
If the option is not found from this manager object either this method
throws a KeyError.
:param model: SQLAlchemy declarative object
:param model_or_table: SQLAlchemy declarative object
:param name: name of the versioning option
"""
if not hasattr(model, '__versioned__'):
raise TypeError('Model %r is not versioned.' % model)
if not hasattr(model_or_table, '__versioned__'):
if isinstance(model_or_table, sa.Table): # with #299 table could also have versioning_manager available
return self.options[name]
raise TypeError('Model %r is not versioned.' % model_or_table)
try:
return model.__versioned__[name]
return model_or_table.__versioned__[name]
except KeyError:
return self.options[name]

Expand Down
34 changes: 22 additions & 12 deletions sqlalchemy_continuum/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,34 @@ def get_versioning_manager(obj_or_class_or_table):
try:
return cls_or_table.__versioning_manager__
except AttributeError:
if issubclass(cls_or_table, sa.Table):
if isinstance(cls_or_table, sa.Table):
name = 'Table "%s"' % cls_or_table.name
else:
name = cls_or_table.__name__
raise ClassNotVersioned(name)


def option(obj_or_class, option_name):
def option(obj_or_class_or_table, option_name):
"""
Return the option value of given option for given versioned object or
class.
:param obj_or_class: SQLAlchemy declarative model object or class
:param option_name: The name of an option to return
"""
if isinstance(obj_or_class, AliasedClass):
obj_or_class = sa.inspect(obj_or_class).mapper.class_
cls = obj_or_class if isclass(obj_or_class) else obj_or_class.__class__
if not hasattr(cls, '__versioned__'):
cls = parent_class(cls)
return get_versioning_manager(cls).option(
cls, option_name
if isclass(obj_or_class_or_table):
cls_or_table = obj_or_class_or_table
else:
if isinstance(obj_or_class_or_table, AliasedClass):
cls_or_table = sa.inspect(obj_or_class_or_table).mapper.class_
elif isinstance(obj_or_class_or_table, sa.Table):
cls_or_table = obj_or_class_or_table
else:
cls_or_table = obj_or_class_or_table.__class__
if isclass(cls_or_table) and not hasattr(cls_or_table, '__versioned__'): # 299 doesn't add versioned dict as part of table builder
cls_or_table = parent_class(cls_or_table)
return get_versioning_manager(cls_or_table).option(
cls_or_table, option_name
)


Expand Down Expand Up @@ -144,17 +150,21 @@ def version_table(table):
:param table: SQLAlchemy Table object
"""
try:
suffixed_table = option(table, 'table_name') % table.name
except ClassNotVersioned:
suffixed_table = table.name + '_version' # to have same behaviour and generate key error for expression_reflector
if table.schema:
return table.metadata.tables[
table.schema + '.' + table.name + '_version'
table.schema + '.' + suffixed_table
]
elif table.metadata.schema:
return table.metadata.tables[
table.metadata.schema + '.' + table.name + '_version'
table.metadata.schema + '.' + suffixed_table
]
else:
return table.metadata.tables[
table.name + '_version'
suffixed_table
]


Expand Down
73 changes: 73 additions & 0 deletions tests/utils/test_version_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pytest
import datetime
import sqlalchemy as sa
from tests import TestCase, uses_native_versioning

from sqlalchemy_continuum.utils import version_table

class TestVersionTableDefault(TestCase):

def create_models(self):
super().create_models()

article_author_table = sa.Table(
'article_author',
self.Model.metadata,
sa.Column('article_id', sa.Integer, sa.ForeignKey('article.id'), primary_key=True, nullable=False),
sa.Column('author_id', sa.Integer, sa.ForeignKey('author.id'), primary_key=True, nullable=False),
sa.Column('created_date', sa.DateTime, nullable=False, server_default=sa.func.current_timestamp(), default=datetime.datetime.utcnow),
)

class Author(self.Model):
__tablename__ = 'author'
__versioned__ = {
'baseclass': (self.Model, )
}
id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
articles = sa.orm.relationship('Article', secondary=article_author_table, backref='author')

class User(self.Model):
__tablename__ = 'user'

id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))

self.User = User
self.Author = Author
self.article_author_table = article_author_table

def test_version_table_with_model(self):
ArticleVersionTableName = version_table(self.Article.__table__)
assert ArticleVersionTableName.fullname == 'article_version'

def test_version_table_with_association_table(self):
ArticleAuthorVersionedTableName = version_table(self.article_author_table)
assert ArticleAuthorVersionedTableName.fullname == 'article_author_version'

def test_version_table_with_non_version_model(self):
with pytest.raises(KeyError):
version_table(self.User.__table__)

class TestVersionTableUserDefined(TestVersionTableDefault):

@property
def options(self):
return {
'create_models': self.should_create_models,
'native_versioning': uses_native_versioning(),
'base_classes': (self.Model, ),
'strategy': self.versioning_strategy,
'transaction_column_name': self.transaction_column_name,
'end_transaction_column_name': self.end_transaction_column_name,
'table_name': '%s_user_defined'
}

def test_version_table_with_model(self):
ArticleVersionTableName = version_table(self.Article.__table__)
assert ArticleVersionTableName.fullname == 'article_user_defined'

def test_version_table_with_association_table(self):
ArticleAuthorVersionedTableName = version_table(self.article_author_table)
assert ArticleAuthorVersionedTableName.fullname == 'article_author_user_defined'

0 comments on commit 7910f83

Please sign in to comment.