Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow tables in get_versioning_manager() #299

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion sqlalchemy_continuum/table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,15 @@ def __call__(self, extends=None):
"""
Builds version table.
"""
self.parent_table.__versioning_manager__ = self.manager
columns = self.columns if extends is None else []
self.manager.plugins.after_build_version_table_columns(self, columns)
return sa.schema.Table(
version_table = sa.schema.Table(
extends.name if extends is not None else self.table_name,
self.parent_table.metadata,
*columns,
schema=self.parent_table.schema,
extend_existing=extends is not None
)
version_table.__versioning_manager__ = self.manager
return version_table
27 changes: 19 additions & 8 deletions sqlalchemy_continuum/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,31 @@
from .exc import ClassNotVersioned


def get_versioning_manager(obj_or_class):
def get_versioning_manager(obj_or_class_or_table):
"""
Return the associated SQLAlchemy-Continuum VersioningManager for given
SQLAlchemy declarative model class or object.
SQLAlchemy declarative model class or object or table.

:param obj_or_class: SQLAlchemy declarative model object or class
:param obj_or_class_or_table: SQLAlchemy declarative model object or class or table
"""
if isinstance(obj_or_class, AliasedClass):
obj_or_class = sa.inspect(obj_or_class).mapper.class_
cls = obj_or_class if isclass(obj_or_class) else obj_or_class.__class__
if isclass(obj_or_class_or_table):
cls_or_table = obj_or_class_or_table
else:
if isinstance(obj_or_class_or_table, AliasedClass):
cls_or_table = sa.inspect(obj_or_class_or_table).mapper.class_
elif isinstance(obj_or_class_or_table, sa.Table):
cls_or_table = obj_or_class_or_table
else:
cls_or_table = obj_or_class_or_table.__class__

try:
return cls.__versioning_manager__
return cls_or_table.__versioning_manager__
except AttributeError:
raise ClassNotVersioned(cls.__name__)
if issubclass(cls_or_table, sa.Table):
name = 'Table "%s"' % cls_or_table.name
else:
name = cls_or_table.__name__
raise ClassNotVersioned(name)


def option(obj_or_class, option_name):
Expand Down
73 changes: 73 additions & 0 deletions tests/utils/test_get_versioning_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from copy import copy
from pytest import raises
import sqlalchemy as sa
from sqlalchemy_continuum import versioning_manager
from sqlalchemy_continuum.exc import ClassNotVersioned
from sqlalchemy_continuum.utils import get_versioning_manager

from tests import TestCase


class TestGetVersioningManager(TestCase):
def create_models(self):
"""
Creates many-to-many relationship between Article and Tag
Article is versioned. But Tag is not versioned
"""
class Article(self.Model):
__tablename__ = 'article'
__versioned__ = copy(self.options)

id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))

article_tag = sa.Table(
'article_tag',
self.Model.metadata,
sa.Column(
'article_id',
sa.Integer,
sa.ForeignKey('article.id', ondelete='CASCADE'),
primary_key=True,
),
sa.Column(
'tag_id',
sa.Integer,
sa.ForeignKey('tag.id', ondelete='CASCADE'),
primary_key=True
)
)

class Tag(self.Model):
__tablename__ = 'tag'

id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
articles = sa.orm.relationship(Article, secondary=article_tag, backref='tags')

self.Article = Article
self.article_tag = article_tag
self.Tag = Tag

def test_parent_class(self):
assert get_versioning_manager(self.Article) == versioning_manager

def test_parent_table(self):
assert get_versioning_manager(self.Article.__table__) == versioning_manager

def test_version_class(self):
assert get_versioning_manager(self.ArticleVersion) == versioning_manager

def test_version_table(self):
assert get_versioning_manager(self.ArticleVersion.__table__) == versioning_manager

def test_association_table(self):
assert get_versioning_manager(self.article_tag) == versioning_manager

def test_aliased_class(self):
assert get_versioning_manager(sa.orm.aliased(self.Article)) == versioning_manager
assert get_versioning_manager(sa.orm.aliased(self.ArticleVersion)) == versioning_manager

def test_unknown_class(self):
with raises(ClassNotVersioned):
get_versioning_manager(self.Tag)