From c598770bb53cf545cafffb4a2c6a36b705f2cef1 Mon Sep 17 00:00:00 2001 From: AbdealiJK Date: Tue, 30 Aug 2022 20:09:11 +0530 Subject: [PATCH] Allow tables in get_versioning_manager() Add Table.__versioning_manager__ so we can keep track of the manager used for a table. Allow table in get_versioning_manager() Also add unittests for the function --- sqlalchemy_continuum/table_builder.py | 5 +- sqlalchemy_continuum/utils.py | 27 +++++--- tests/utils/test_get_versioning_manager.py | 73 ++++++++++++++++++++++ 3 files changed, 96 insertions(+), 9 deletions(-) create mode 100644 tests/utils/test_get_versioning_manager.py diff --git a/sqlalchemy_continuum/table_builder.py b/sqlalchemy_continuum/table_builder.py index 600d1e02..f183d6e1 100644 --- a/sqlalchemy_continuum/table_builder.py +++ b/sqlalchemy_continuum/table_builder.py @@ -141,12 +141,15 @@ def __call__(self, extends=None): """ Builds version table. """ + self.parent_table.__versioning_manager__ = self.manager columns = self.columns if extends is None else [] self.manager.plugins.after_build_version_table_columns(self, columns) - return sa.schema.Table( + version_table = sa.schema.Table( extends.name if extends is not None else self.table_name, self.parent_table.metadata, *columns, schema=self.parent_table.schema, extend_existing=extends is not None ) + version_table.__versioning_manager__ = self.manager + return version_table diff --git a/sqlalchemy_continuum/utils.py b/sqlalchemy_continuum/utils.py index efca7b6f..a87f58ab 100644 --- a/sqlalchemy_continuum/utils.py +++ b/sqlalchemy_continuum/utils.py @@ -14,20 +14,31 @@ from .exc import ClassNotVersioned -def get_versioning_manager(obj_or_class): +def get_versioning_manager(obj_or_class_or_table): """ Return the associated SQLAlchemy-Continuum VersioningManager for given - SQLAlchemy declarative model class or object. + SQLAlchemy declarative model class or object or table. - :param obj_or_class: SQLAlchemy declarative model object or class + :param obj_or_class_or_table: SQLAlchemy declarative model object or class or table """ - if isinstance(obj_or_class, AliasedClass): - obj_or_class = sa.inspect(obj_or_class).mapper.class_ - cls = obj_or_class if isclass(obj_or_class) else obj_or_class.__class__ + if isclass(obj_or_class_or_table): + cls_or_table = obj_or_class_or_table + else: + if isinstance(obj_or_class_or_table, AliasedClass): + cls_or_table = sa.inspect(obj_or_class_or_table).mapper.class_ + elif isinstance(obj_or_class_or_table, sa.Table): + cls_or_table = obj_or_class_or_table + else: + cls_or_table = obj_or_class_or_table.__class__ + try: - return cls.__versioning_manager__ + return cls_or_table.__versioning_manager__ except AttributeError: - raise ClassNotVersioned(cls.__name__) + if issubclass(cls_or_table, sa.Table): + name = 'Table "%s"' % cls_or_table.name + else: + name = cls_or_table.__name__ + raise ClassNotVersioned(name) def option(obj_or_class, option_name): diff --git a/tests/utils/test_get_versioning_manager.py b/tests/utils/test_get_versioning_manager.py new file mode 100644 index 00000000..cc8a23a6 --- /dev/null +++ b/tests/utils/test_get_versioning_manager.py @@ -0,0 +1,73 @@ +from copy import copy +from pytest import raises +import sqlalchemy as sa +from sqlalchemy_continuum import versioning_manager +from sqlalchemy_continuum.exc import ClassNotVersioned +from sqlalchemy_continuum.utils import get_versioning_manager + +from tests import TestCase + + +class TestGetVersioningManager(TestCase): + def create_models(self): + """ + Creates many-to-many relationship between Article and Tag + Article is versioned. But Tag is not versioned + """ + class Article(self.Model): + __tablename__ = 'article' + __versioned__ = copy(self.options) + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + article_tag = sa.Table( + 'article_tag', + self.Model.metadata, + sa.Column( + 'article_id', + sa.Integer, + sa.ForeignKey('article.id', ondelete='CASCADE'), + primary_key=True, + ), + sa.Column( + 'tag_id', + sa.Integer, + sa.ForeignKey('tag.id', ondelete='CASCADE'), + primary_key=True + ) + ) + + class Tag(self.Model): + __tablename__ = 'tag' + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + articles = sa.orm.relationship(Article, secondary=article_tag, backref='tags') + + self.Article = Article + self.article_tag = article_tag + self.Tag = Tag + + def test_parent_class(self): + assert get_versioning_manager(self.Article) == versioning_manager + + def test_parent_table(self): + assert get_versioning_manager(self.Article.__table__) == versioning_manager + + def test_version_class(self): + assert get_versioning_manager(self.ArticleVersion) == versioning_manager + + def test_version_table(self): + assert get_versioning_manager(self.ArticleVersion.__table__) == versioning_manager + + def test_association_table(self): + assert get_versioning_manager(self.article_tag) == versioning_manager + + def test_aliased_class(self): + assert get_versioning_manager(sa.orm.aliased(self.Article)) == versioning_manager + assert get_versioning_manager(sa.orm.aliased(self.ArticleVersion)) == versioning_manager + + def test_unknown_class(self): + with raises(ClassNotVersioned): + get_versioning_manager(self.Tag)