From 33f7d3c95130383219f4a3ec3045ef5e61af736b Mon Sep 17 00:00:00 2001 From: Matt F Date: Tue, 30 Jul 2024 14:53:43 -0400 Subject: [PATCH 1/7] Updates imports Updates sqlalchemy_utils imports to point at specific packages and removes a few unused imports. This was done in hope of not initializing the whole of SA-Utils, but that requires a different approach. --- sqlalchemy_continuum/fetcher.py | 2 +- sqlalchemy_continuum/manager.py | 4 +--- sqlalchemy_continuum/operation.py | 2 +- sqlalchemy_continuum/plugins/activity.py | 5 +++-- sqlalchemy_continuum/plugins/flask.py | 2 +- sqlalchemy_continuum/table_builder.py | 1 - sqlalchemy_continuum/unit_of_work.py | 2 +- 7 files changed, 8 insertions(+), 10 deletions(-) diff --git a/sqlalchemy_continuum/fetcher.py b/sqlalchemy_continuum/fetcher.py index 689e3b54..76e3b953 100644 --- a/sqlalchemy_continuum/fetcher.py +++ b/sqlalchemy_continuum/fetcher.py @@ -1,6 +1,6 @@ import operator import sqlalchemy as sa -from sqlalchemy_utils import get_primary_keys, identity +from sqlalchemy_utils.functions.orm import get_primary_keys, identity from .utils import tx_column_name, end_tx_column_name diff --git a/sqlalchemy_continuum/manager.py b/sqlalchemy_continuum/manager.py index e115a27a..3966cfb5 100644 --- a/sqlalchemy_continuum/manager.py +++ b/sqlalchemy_continuum/manager.py @@ -1,9 +1,8 @@ -import re from functools import wraps import sqlalchemy as sa from sqlalchemy.orm import object_session -from sqlalchemy_utils import get_column_key +from sqlalchemy_utils.functions import get_column_key from .builder import Builder from .fetcher import SubqueryFetcher, ValidityFetcher @@ -454,4 +453,3 @@ def track_association_operations( 'operation_type': op, }) uow.pending_statements.append(stmt) - diff --git a/sqlalchemy_continuum/operation.py b/sqlalchemy_continuum/operation.py index 315f25c8..a817ad6c 100644 --- a/sqlalchemy_continuum/operation.py +++ b/sqlalchemy_continuum/operation.py @@ -2,7 +2,7 @@ from collections import OrderedDict import sqlalchemy as sa -from sqlalchemy_utils import identity +from sqlalchemy_utils.functions.orm import identity class Operation(object): diff --git a/sqlalchemy_continuum/plugins/activity.py b/sqlalchemy_continuum/plugins/activity.py index 10b85d3f..f5aab09b 100644 --- a/sqlalchemy_continuum/plugins/activity.py +++ b/sqlalchemy_continuum/plugins/activity.py @@ -192,7 +192,8 @@ import sqlalchemy as sa from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.inspection import inspect -from sqlalchemy_utils import JSONType, generic_relationship +from sqlalchemy_utils.generic import generic_relationship +from sqlalchemy_utils.types.json import JSONType from .base import Plugin from ..factory import ModelFactory @@ -318,7 +319,7 @@ def target_version_type(cls): class ActivityPlugin(Plugin): activity_cls = None - + def after_build_models(self, manager): self.activity_cls = ActivityFactory()(manager) manager.activity_cls = self.activity_cls diff --git a/sqlalchemy_continuum/plugins/flask.py b/sqlalchemy_continuum/plugins/flask.py index f135f8bf..b8acfa46 100644 --- a/sqlalchemy_continuum/plugins/flask.py +++ b/sqlalchemy_continuum/plugins/flask.py @@ -24,7 +24,7 @@ from flask import current_app, has_app_context, has_request_context, request except ImportError: pass -from sqlalchemy_utils import ImproperlyConfigured +from sqlalchemy_utils.exceptions import ImproperlyConfigured from .base import Plugin diff --git a/sqlalchemy_continuum/table_builder.py b/sqlalchemy_continuum/table_builder.py index 9666b2de..423aba14 100644 --- a/sqlalchemy_continuum/table_builder.py +++ b/sqlalchemy_continuum/table_builder.py @@ -1,5 +1,4 @@ import sqlalchemy as sa -from sqlalchemy_utils import get_column_key class ColumnReflector(object): diff --git a/sqlalchemy_continuum/unit_of_work.py b/sqlalchemy_continuum/unit_of_work.py index cc5839ce..0c270a33 100644 --- a/sqlalchemy_continuum/unit_of_work.py +++ b/sqlalchemy_continuum/unit_of_work.py @@ -1,7 +1,7 @@ from copy import copy import sqlalchemy as sa -from sqlalchemy_utils import get_primary_keys, identity +from sqlalchemy_utils.functions.orm import get_primary_keys, identity from .operation import Operations from .utils import ( end_tx_column_name, From 37e7e0a25d1809afe6d68161c3978369847de6b1 Mon Sep 17 00:00:00 2001 From: Matt F Date: Wed, 7 Aug 2024 16:03:59 -0400 Subject: [PATCH 2/7] Removes SQLAlchemy-Utils dependency This is a breaking change because classes copied from SQLAlchemy-Utils but exposed through this library will have a new parent module. Utility classes and methods from SQLAlchemy-Utils are all copied into `sa_utils.py` with no functional changes. Only code required by Continuum was ported. Subsequent commits will implement cleanup and other changes discussed in --- CHANGES.rst | 6 + setup.py | 1 - sqlalchemy_continuum/__init__.py | 2 +- sqlalchemy_continuum/builder.py | 2 +- sqlalchemy_continuum/fetcher.py | 3 +- sqlalchemy_continuum/manager.py | 2 +- sqlalchemy_continuum/model_builder.py | 2 +- sqlalchemy_continuum/operation.py | 3 +- sqlalchemy_continuum/plugins/activity.py | 3 +- sqlalchemy_continuum/plugins/flask.py | 2 +- .../plugins/property_mod_tracker.py | 3 +- sqlalchemy_continuum/sa_utils.py | 580 ++++++++++++++++++ sqlalchemy_continuum/unit_of_work.py | 3 +- sqlalchemy_continuum/utils.py | 6 +- 14 files changed, 601 insertions(+), 17 deletions(-) create mode 100644 sqlalchemy_continuum/sa_utils.py diff --git a/CHANGES.rst b/CHANGES.rst index d2a4f8d8..ed2990d2 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -3,6 +3,12 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Continuum release. +2.0.0 (Unreleased) +^^^^^^^^^^^^^^^^^ + +- Removed direct dependency on SQLAlchemy-Utils to improve initialization times + + 1.3.14 (2023-01-04) ^^^^^^^^^^^^^^^^^^^ diff --git a/setup.py b/setup.py index 5e6743cc..094cd2d9 100644 --- a/setup.py +++ b/setup.py @@ -61,7 +61,6 @@ def get_version(): platforms='any', install_requires=[ 'SQLAlchemy>=1.4.0', - 'SQLAlchemy-Utils>=0.41.2', ], extras_require=extras_require, classifiers=[ diff --git a/sqlalchemy_continuum/__init__.py b/sqlalchemy_continuum/__init__.py index 79e243b8..22289482 100644 --- a/sqlalchemy_continuum/__init__.py +++ b/sqlalchemy_continuum/__init__.py @@ -18,7 +18,7 @@ ) -__version__ = '1.4.2' +__version__ = '2.0.0' versioning_manager = VersioningManager() diff --git a/sqlalchemy_continuum/builder.py b/sqlalchemy_continuum/builder.py index 3c0072fb..c29eb652 100644 --- a/sqlalchemy_continuum/builder.py +++ b/sqlalchemy_continuum/builder.py @@ -3,12 +3,12 @@ from functools import wraps import sqlalchemy as sa -from sqlalchemy_utils.functions import get_declarative_base from sqlalchemy.orm.descriptor_props import ConcreteInheritedProperty from .dialects.postgresql import create_versioning_trigger_listeners from .model_builder import ModelBuilder from .relationship_builder import RelationshipBuilder +from .sa_utils import get_declarative_base from .table_builder import TableBuilder diff --git a/sqlalchemy_continuum/fetcher.py b/sqlalchemy_continuum/fetcher.py index 76e3b953..dfd9e1c7 100644 --- a/sqlalchemy_continuum/fetcher.py +++ b/sqlalchemy_continuum/fetcher.py @@ -1,6 +1,7 @@ import operator import sqlalchemy as sa -from sqlalchemy_utils.functions.orm import get_primary_keys, identity + +from .sa_utils import get_primary_keys, identity from .utils import tx_column_name, end_tx_column_name diff --git a/sqlalchemy_continuum/manager.py b/sqlalchemy_continuum/manager.py index 3966cfb5..451d15e2 100644 --- a/sqlalchemy_continuum/manager.py +++ b/sqlalchemy_continuum/manager.py @@ -2,12 +2,12 @@ import sqlalchemy as sa from sqlalchemy.orm import object_session -from sqlalchemy_utils.functions import get_column_key from .builder import Builder from .fetcher import SubqueryFetcher, ValidityFetcher from .operation import Operation from .plugins import PluginCollection +from .sa_utils import get_column_key from .transaction import TransactionFactory from .unit_of_work import UnitOfWork from .utils import is_modified, is_versioned, version_table diff --git a/sqlalchemy_continuum/model_builder.py b/sqlalchemy_continuum/model_builder.py index 3906753c..0553bbd6 100644 --- a/sqlalchemy_continuum/model_builder.py +++ b/sqlalchemy_continuum/model_builder.py @@ -2,8 +2,8 @@ import sqlalchemy as sa from sqlalchemy.ext.declarative import declared_attr from sqlalchemy.orm import column_property -from sqlalchemy_utils.functions import get_declarative_base +from .sa_utils import get_declarative_base from .utils import adapt_columns, option from .version import VersionClassBase diff --git a/sqlalchemy_continuum/operation.py b/sqlalchemy_continuum/operation.py index a817ad6c..b6bf21bb 100644 --- a/sqlalchemy_continuum/operation.py +++ b/sqlalchemy_continuum/operation.py @@ -2,7 +2,8 @@ from collections import OrderedDict import sqlalchemy as sa -from sqlalchemy_utils.functions.orm import identity + +from .sa_utils import identity class Operation(object): diff --git a/sqlalchemy_continuum/plugins/activity.py b/sqlalchemy_continuum/plugins/activity.py index f5aab09b..ba51f494 100644 --- a/sqlalchemy_continuum/plugins/activity.py +++ b/sqlalchemy_continuum/plugins/activity.py @@ -192,11 +192,10 @@ import sqlalchemy as sa from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.inspection import inspect -from sqlalchemy_utils.generic import generic_relationship -from sqlalchemy_utils.types.json import JSONType from .base import Plugin from ..factory import ModelFactory +from ..sa_utils import JSONType, generic_relationship from ..utils import version_class, version_obj diff --git a/sqlalchemy_continuum/plugins/flask.py b/sqlalchemy_continuum/plugins/flask.py index b8acfa46..cc3ed82a 100644 --- a/sqlalchemy_continuum/plugins/flask.py +++ b/sqlalchemy_continuum/plugins/flask.py @@ -24,9 +24,9 @@ from flask import current_app, has_app_context, has_request_context, request except ImportError: pass -from sqlalchemy_utils.exceptions import ImproperlyConfigured from .base import Plugin +from ..sa_utils import ImproperlyConfigured def fetch_current_user_id(): diff --git a/sqlalchemy_continuum/plugins/property_mod_tracker.py b/sqlalchemy_continuum/plugins/property_mod_tracker.py index 498eb07f..34cc0067 100644 --- a/sqlalchemy_continuum/plugins/property_mod_tracker.py +++ b/sqlalchemy_continuum/plugins/property_mod_tracker.py @@ -16,8 +16,9 @@ from copy import copy import sqlalchemy as sa -from sqlalchemy_utils.functions import has_changes + from .base import Plugin +from ..sa_utils import has_changes from ..utils import versioned_column_properties diff --git a/sqlalchemy_continuum/sa_utils.py b/sqlalchemy_continuum/sa_utils.py new file mode 100644 index 00000000..abfee583 --- /dev/null +++ b/sqlalchemy_continuum/sa_utils.py @@ -0,0 +1,580 @@ +from collections import OrderedDict +from collections.abc import Iterable +from inspect import isclass +import json + +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql.base import ischema_names +from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.orm import ColumnProperty, class_mapper +from sqlalchemy.orm.attributes import ( + InstrumentedAttribute, + ScalarAttributeImpl, + register_attribute, +) +from sqlalchemy.orm.base import PASSIVE_OFF +from sqlalchemy.orm.interfaces import MapperProperty, PropComparator +from sqlalchemy.orm.session import _state_session +from sqlalchemy.util import set_creation_order + +try: + from sqlalchemy.dialects.postgresql import JSON + has_postgres_json = True +except ImportError: + class PostgresJSONType(sa.types.UserDefinedType): + """ + Text search vector type for postgresql. + """ + def get_col_spec(self): + return 'json' + + ischema_names['json'] = PostgresJSONType + has_postgres_json = False + + +def _get_class_registry(class_): + try: + return class_.registry._class_registry + except AttributeError: # SQLAlchemy <1.4 + return class_._decl_class_registry + + +def _get_columns(mixed): + """ + Return a collection of all Column objects for given SQLAlchemy + object. + + The type of the collection depends on the type of the object to return the + columns from. + + :: + + get_columns(User) + + get_columns(User()) + + get_columns(User.__table__) + + get_columns(User.__mapper__) + + get_columns(sa.orm.aliased(User)) + + get_columns(sa.orm.alised(User.__table__)) + + + :param mixed: + SA Table object, SA Mapper, SA declarative class, SA declarative class + instance or an alias of any of these objects + """ + if isinstance(mixed, sa.sql.selectable.Selectable): + try: + return mixed.selected_columns + except AttributeError: # SQLAlchemy <1.4 + return mixed.c + if isinstance(mixed, sa.orm.util.AliasedClass): + return sa.inspect(mixed).mapper.columns + if isinstance(mixed, sa.orm.Mapper): + return mixed.columns + if isinstance(mixed, InstrumentedAttribute): + return mixed.property.columns + if isinstance(mixed, ColumnProperty): + return mixed.columns + if isinstance(mixed, sa.Column): + return [mixed] + if not isclass(mixed): + mixed = mixed.__class__ + return sa.inspect(mixed).columns + + +class GenericAttributeImpl(ScalarAttributeImpl): + def __init__(self, *args, **kwargs): + """ + The constructor of attributes.AttributeImpl changed in SQLAlchemy 2.0.22, + adding a 'default_function' required positional argument before 'dispatch'. + This adjustment ensures compatibility across versions by inserting None for + 'default_function' in versions >= 2.0.22. + + Arguments received: (class, key, dispatch) + Required by AttributeImpl: (class, key, default_function, dispatch) + Setting None as default_function here. + """ + # Adjust for SQLAlchemy version change + sqlalchemy_version = tuple(map(int, sa.__version__.split('.'))) + if sqlalchemy_version >= (2, 0, 22): + args = (*args[:2], None, *args[2:]) + + super().__init__(*args, **kwargs) + + def get(self, state, dict_, passive=PASSIVE_OFF): + if self.key in dict_: + return dict_[self.key] + + # Retrieve the session bound to the state in order to perform + # a lazy query for the attribute. + # TODO: replace this with sa.orm.session.object_session? + session = _state_session(state) + if session is None: + # State is not bound to a session; we cannot proceed. + return None + + # Find class for discriminator. + # TODO: Perhaps optimize with some sort of lookup? + discriminator = self.get_state_discriminator(state) + target_class = _get_class_registry(state.class_).get(discriminator) + + if target_class is None: + # Unknown discriminator; return nothing. + return None + + id = self.get_state_id(state) + + try: + target = session.get(target_class, id) + except AttributeError: + # sqlalchemy 1.3 + target = session.query(target_class).get(id) + + # Return found (or not found) target. + return target + + def get_state_discriminator(self, state): + discriminator = self.parent_token.discriminator + if isinstance(discriminator, hybrid_property): + return getattr(state.obj(), discriminator.__name__) + else: + return state.attrs[discriminator.key].value + + def get_state_id(self, state): + # Lookup row with the discriminator and id. + return tuple(state.attrs[id.key].value for id in self.parent_token.id) + + def set(self, state, dict_, initiator, + passive=PASSIVE_OFF, + check_old=None, + pop=False): + + # Set us on the state. + dict_[self.key] = initiator + + if initiator is None: + # Nullify relationship args + for id in self.parent_token.id: + dict_[id.key] = None + dict_[self.parent_token.discriminator.key] = None + else: + # Get the primary key of the initiator and ensure we + # can support this assignment. + class_ = type(initiator) + mapper = class_mapper(class_) + + pk = mapper.identity_key_from_instance(initiator)[1] + + # Set the identifier and the discriminator. + discriminator = class_.__name__ + + for index, id in enumerate(self.parent_token.id): + dict_[id.key] = pk[index] + dict_[self.parent_token.discriminator.key] = discriminator + + +class GenericRelationshipProperty(MapperProperty): + """A generic form of the relationship property. + + Creates a 1 to many relationship between the parent model + and any other models using a discriminator (the table name). + + :param discriminator + Field to discriminate which model we are referring to. + :param id: + Field to point to the model we are referring to. + """ + + def __init__(self, discriminator, id, doc=None): + super().__init__() + self._discriminator_col = discriminator + self._id_cols = id + self._id = None + self._discriminator = None + self.doc = doc + + set_creation_order(self) + + def _column_to_property(self, column): + if isinstance(column, hybrid_property): + attr_key = column.__name__ + for key, attr in self.parent.all_orm_descriptors.items(): + if key == attr_key: + return attr + else: + for attr in self.parent.attrs.values(): + if isinstance(attr, ColumnProperty): + if attr.columns[0].name == column.name: + return attr + + def init(self): + def convert_strings(column): + if isinstance(column, str): + return self.parent.columns[column] + return column + + self._discriminator_col = convert_strings(self._discriminator_col) + self._id_cols = convert_strings(self._id_cols) + + if isinstance(self._id_cols, Iterable): + self._id_cols = list(map(convert_strings, self._id_cols)) + else: + self._id_cols = [self._id_cols] + + self.discriminator = self._column_to_property(self._discriminator_col) + + if self.discriminator is None: + raise ImproperlyConfigured( + 'Could not find discriminator descriptor.' + ) + + self.id = list(map(self._column_to_property, self._id_cols)) + + class Comparator(PropComparator): + def __init__(self, prop, parentmapper): + self.property = prop + self._parententity = parentmapper + + def __eq__(self, other): + discriminator = type(other).__name__ + q = self.property._discriminator_col == discriminator + other_id = identity(other) + for index, id in enumerate(self.property._id_cols): + q &= id == other_id[index] + return q + + def __ne__(self, other): + return ~(self == other) + + def is_type(self, other): + mapper = sa.inspect(other) + # Iterate through the weak sequence in order to get the actual + # mappers + class_names = [other.__name__] + class_names.extend([ + submapper.class_.__name__ + for submapper in mapper._inheriting_mappers + ]) + + return self.property._discriminator_col.in_(class_names) + + def instrument_class(self, mapper): + register_attribute( + mapper.class_, + self.key, + comparator=self.Comparator(self, mapper), + parententity=mapper, + doc=self.doc, + impl_class=GenericAttributeImpl, + parent_token=self + ) + + +class ImproperlyConfigured(Exception): + """ + SQLAlchemy-Continuum is improperly configured; normally due to usage of + a utility that depends on a missing library. + """ + + +class JSONType(sa.types.TypeDecorator): + """ + JSONType offers way of saving JSON data structures to database. On + PostgreSQL the underlying implementation of this data type is 'json' while + on other databases its simply 'text'. + + :: + + + from sqlalchemy_continuum.sa_utils import JSONType + + + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(50)) + details = sa.Column(JSONType) + + + product = Product() + product.details = { + 'color': 'red', + 'type': 'car', + 'max-speed': '400 mph' + } + session.commit() + """ + impl = sa.UnicodeText + hashable = False + cache_ok = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def load_dialect_impl(self, dialect): + if dialect.name == 'postgresql': + # Use the native JSON type. + if has_postgres_json: + return dialect.type_descriptor(JSON()) + else: + return dialect.type_descriptor(PostgresJSONType()) + else: + return dialect.type_descriptor(self.impl) + + def process_bind_param(self, value, dialect): + if dialect.name == 'postgresql' and has_postgres_json: + return value + if value is not None: + value = json.dumps(value) + return value + + def process_result_value(self, value, dialect): + if dialect.name == 'postgresql': + return value + if value is not None: + value = json.loads(value) + return value + + +def generic_relationship(*args, **kwargs): + return GenericRelationshipProperty(*args, **kwargs) + + +def get_column_key(model, column): + """ + Return the key for given column in given model. + + :param model: SQLAlchemy declarative model object + + :: + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column('_name', sa.String) + + + get_column_key(User, User.__table__.c._name) # 'name' + + .. versionadded: 0.26.5 + + .. versionchanged: 0.27.11 + Throws UnmappedColumnError instead of ValueError when no property was + found for given column. This is consistent with how SQLAlchemy works. + """ + mapper = sa.inspect(model) + try: + return mapper.get_property_by_column(column).key + except sa.orm.exc.UnmappedColumnError: + for key, c in mapper.columns.items(): + if c.name == column.name and c.table is column.table: + return key + raise sa.orm.exc.UnmappedColumnError( + 'No column %s is configured on mapper %s...' % + (column, mapper) + ) + + +def get_declarative_base(model): + """ + Returns the declarative base for given model class. + + :param model: SQLAlchemy declarative model + """ + for parent in model.__bases__: + try: + parent.metadata + return get_declarative_base(parent) + except AttributeError: + pass + return model + + +def get_primary_keys(mixed): + """ + Return an OrderedDict of all primary keys for given Table object, + declarative class or declarative class instance. + + :param mixed: + SA Table object, SA declarative class or SA declarative class instance + + :: + + get_primary_keys(User) + + get_primary_keys(User()) + + get_primary_keys(User.__table__) + + get_primary_keys(User.__mapper__) + + get_primary_keys(sa.orm.aliased(User)) + + get_primary_keys(sa.orm.aliased(User.__table__)) + + + .. versionchanged: 0.25.3 + Made the function return an ordered dictionary instead of generator. + This change was made to support primary key aliases. + + Renamed this function to 'get_primary_keys', formerly 'primary_keys' + + .. seealso:: :func:`get_columns` + """ + return OrderedDict( + ( + (key, column) for key, column in _get_columns(mixed).items() + if column.primary_key + ) + ) + + +def has_changes(obj, attrs=None, exclude=None): + """ + Simple shortcut function for checking if given attributes of given + declarative model object have changed during the session. Without + parameters this checks if given object has any modificiations. Additionally + exclude parameter can be given to check if given object has any changes + in any attributes other than the ones given in exclude. + + + :: + + + from sqlalchemy_continuum.sa_utils import has_changes + + + user = User() + + has_changes(user, 'name') # False + + user.name = 'someone' + + has_changes(user, 'name') # True + + has_changes(user) # True + + + You can check multiple attributes as well. + :: + + + has_changes(user, ['age']) # True + + has_changes(user, ['name', 'age']) # True + + + This function also supports excluding certain attributes. + + :: + + has_changes(user, exclude=['name']) # False + + has_changes(user, exclude=['age']) # True + + .. versionchanged: 0.26.6 + Added support for multiple attributes and exclude parameter. + + :param obj: SQLAlchemy declarative model object + :param attrs: Names of the attributes + :param exclude: Names of the attributes to exclude + """ + if attrs: + if isinstance(attrs, str): + return ( + sa.inspect(obj) + .attrs + .get(attrs) + .history + .has_changes() + ) + else: + return any(has_changes(obj, attr) for attr in attrs) + else: + if exclude is None: + exclude = [] + return any( + attr.history.has_changes() + for key, attr in sa.inspect(obj).attrs.items() + if key not in exclude + ) + + +def identity(obj_or_class): + """ + Return the identity of given sqlalchemy declarative model class or instance + as a tuple. This differs from obj._sa_instance_state.identity in a way that + it always returns the identity even if object is still in transient state ( + new object that is not yet persisted into database). Also for classes it + returns the identity attributes. + + :: + + from sqlalchemy import inspect + from sqlalchemy_continuum.sa_utils import identity + + + user = User(name='John Matrix') + session.add(user) + identity(user) # None + inspect(user).identity # None + + session.flush() # User now has id but is still in transient state + + identity(user) # (1,) + inspect(user).identity # None + + session.commit() + + identity(user) # (1,) + inspect(user).identity # (1, ) + + + You can also use identity for classes:: + + + identity(User) # (User.id, ) + + .. versionadded: 0.21.0 + + :param obj: SQLAlchemy declarative model object + """ + return tuple( + getattr(obj_or_class, column_key) + for column_key in get_primary_keys(obj_or_class).keys() + ) + + +def naturally_equivalent(obj, obj2): + """ + Returns whether two given SQLAlchemy declarative instances are + naturally equivalent (all their non-primary key properties are equivalent). + + + :: + + from sqlalchemy_continuum.sa_utils import naturally_equivalent + + + user = User(name='someone') + user2 = User(name='someone') + + user == user2 # False + + naturally_equivalent(user, user2) # True + + + :param obj: SQLAlchemy declarative model object + :param obj2: SQLAlchemy declarative model object to compare with `obj` + """ + for column_key, column in sa.inspect(obj.__class__).columns.items(): + if column.primary_key: + continue + + if not (getattr(obj, column_key) == getattr(obj2, column_key)): + return False + return True diff --git a/sqlalchemy_continuum/unit_of_work.py b/sqlalchemy_continuum/unit_of_work.py index 0c270a33..2fa2d410 100644 --- a/sqlalchemy_continuum/unit_of_work.py +++ b/sqlalchemy_continuum/unit_of_work.py @@ -1,8 +1,9 @@ from copy import copy import sqlalchemy as sa -from sqlalchemy_utils.functions.orm import get_primary_keys, identity + from .operation import Operations +from .sa_utils import get_primary_keys, identity from .utils import ( end_tx_column_name, version_class, diff --git a/sqlalchemy_continuum/utils.py b/sqlalchemy_continuum/utils.py index 7953391d..fe99779a 100644 --- a/sqlalchemy_continuum/utils.py +++ b/sqlalchemy_continuum/utils.py @@ -5,13 +5,9 @@ import sqlalchemy as sa from sqlalchemy.orm.attributes import get_history from sqlalchemy.orm.util import AliasedClass -from sqlalchemy_utils.functions import ( - get_primary_keys, - identity, - naturally_equivalent, -) from .exc import ClassNotVersioned +from .sa_utils import get_primary_keys, identity, naturally_equivalent def get_versioning_manager(obj_or_class): From 8c471fe430626ec9b00f44130a603a3348718d63 Mon Sep 17 00:00:00 2001 From: Matt F Date: Wed, 7 Aug 2024 16:27:09 -0400 Subject: [PATCH 3/7] Removes use of OrderedDict Updates the `get_primary_keys` method to only return the primary key columns (without their values) in a list, since the values went unused. The method is thus renamed to `get_primary_key_columns`. --- sqlalchemy_continuum/fetcher.py | 6 ++--- sqlalchemy_continuum/sa_utils.py | 38 ++++++++++------------------ sqlalchemy_continuum/unit_of_work.py | 4 +-- sqlalchemy_continuum/utils.py | 4 +-- 4 files changed, 20 insertions(+), 32 deletions(-) diff --git a/sqlalchemy_continuum/fetcher.py b/sqlalchemy_continuum/fetcher.py index dfd9e1c7..af3c6133 100644 --- a/sqlalchemy_continuum/fetcher.py +++ b/sqlalchemy_continuum/fetcher.py @@ -1,14 +1,14 @@ import operator import sqlalchemy as sa -from .sa_utils import get_primary_keys, identity +from .sa_utils import get_primary_key_columns, identity from .utils import tx_column_name, end_tx_column_name def parent_identity(obj_or_class): return tuple( getattr(obj_or_class, column_key) - for column_key in get_primary_keys(obj_or_class).keys() + for column_key in get_primary_key_columns(obj_or_class) if column_key != tx_column_name(obj_or_class) ) @@ -84,7 +84,7 @@ def _transaction_id_subquery(self, obj, next_or_prev='next', alias=None): ), *[ getattr(attrs, pk) == getattr(obj, pk) - for pk in get_primary_keys(obj.__class__) + for pk in get_primary_key_columns(obj.__class__) if pk != tx_column_name(obj) ] ) diff --git a/sqlalchemy_continuum/sa_utils.py b/sqlalchemy_continuum/sa_utils.py index abfee583..815c8e27 100644 --- a/sqlalchemy_continuum/sa_utils.py +++ b/sqlalchemy_continuum/sa_utils.py @@ -1,4 +1,3 @@ -from collections import OrderedDict from collections.abc import Iterable from inspect import isclass import json @@ -394,9 +393,9 @@ def get_declarative_base(model): return model -def get_primary_keys(mixed): +def get_primary_key_columns(mixed): """ - Return an OrderedDict of all primary keys for given Table object, + Return all primary key names for given Table object, declarative class or declarative class instance. :param mixed: @@ -404,33 +403,22 @@ def get_primary_keys(mixed): :: - get_primary_keys(User) + get_primary_key_columns(User) - get_primary_keys(User()) + get_primary_key_columns(User()) - get_primary_keys(User.__table__) + get_primary_key_columns(User.__table__) - get_primary_keys(User.__mapper__) + get_primary_key_columns(User.__mapper__) - get_primary_keys(sa.orm.aliased(User)) + get_primary_key_columns(sa.orm.aliased(User)) - get_primary_keys(sa.orm.aliased(User.__table__)) - - - .. versionchanged: 0.25.3 - Made the function return an ordered dictionary instead of generator. - This change was made to support primary key aliases. - - Renamed this function to 'get_primary_keys', formerly 'primary_keys' - - .. seealso:: :func:`get_columns` + get_primary_key_columns(sa.orm.aliased(User.__table__)) """ - return OrderedDict( - ( - (key, column) for key, column in _get_columns(mixed).items() - if column.primary_key - ) - ) + return [ + key for key, column in _get_columns(mixed).items() + if column.primary_key + ] def has_changes(obj, attrs=None, exclude=None): @@ -545,7 +533,7 @@ def identity(obj_or_class): """ return tuple( getattr(obj_or_class, column_key) - for column_key in get_primary_keys(obj_or_class).keys() + for column_key in get_primary_key_columns(obj_or_class) ) diff --git a/sqlalchemy_continuum/unit_of_work.py b/sqlalchemy_continuum/unit_of_work.py index 2fa2d410..1bb21d21 100644 --- a/sqlalchemy_continuum/unit_of_work.py +++ b/sqlalchemy_continuum/unit_of_work.py @@ -3,7 +3,7 @@ import sqlalchemy as sa from .operation import Operations -from .sa_utils import get_primary_keys, identity +from .sa_utils import get_primary_key_columns, identity from .utils import ( end_tx_column_name, version_class, @@ -257,7 +257,7 @@ def update_version_validity(self, parent, version_obj): *[ getattr(version_obj, pk) == getattr(class_.__table__.c, pk) - for pk in get_primary_keys(class_) + for pk in get_primary_key_columns(class_) if pk != tx_column_name(class_) ] ) diff --git a/sqlalchemy_continuum/utils.py b/sqlalchemy_continuum/utils.py index fe99779a..359ea734 100644 --- a/sqlalchemy_continuum/utils.py +++ b/sqlalchemy_continuum/utils.py @@ -7,7 +7,7 @@ from sqlalchemy.orm.util import AliasedClass from .exc import ClassNotVersioned -from .sa_utils import get_primary_keys, identity, naturally_equivalent +from .sa_utils import get_primary_key_columns, identity, naturally_equivalent def get_versioning_manager(obj_or_class): @@ -386,7 +386,7 @@ def count_versions(obj): table_name = manager.option(obj, 'table_name') % obj.__table__.name criteria = [ '%s = %r' % (pk, getattr(obj, pk)) - for pk in get_primary_keys(obj) + for pk in get_primary_key_columns(obj) ] query = sa.text('SELECT COUNT(1) FROM %s WHERE %s' % ( table_name, From d9074841503ce3042678b75108845bd7e819132a Mon Sep 17 00:00:00 2001 From: Matt F Date: Wed, 7 Aug 2024 17:29:04 -0400 Subject: [PATCH 4/7] Moves plugin-specific code to plugins module Quite a bit of the code migrated in from SQLAlchemy-Utils only gets used within the plugins module. There's no need for us to incur the initialization expense if those classes and methods aren't needed or used. This will only actually have an initialization benefit if the plugins module does not load every plugin upon import. However, having the code co-located anyway makes sense. --- sqlalchemy_continuum/plugins/activity.py | 300 ++++++++++++++++++++++- sqlalchemy_continuum/plugins/flask.py | 3 +- sqlalchemy_continuum/sa_utils.py | 295 +--------------------- 3 files changed, 299 insertions(+), 299 deletions(-) diff --git a/sqlalchemy_continuum/plugins/activity.py b/sqlalchemy_continuum/plugins/activity.py index ba51f494..53f47573 100644 --- a/sqlalchemy_continuum/plugins/activity.py +++ b/sqlalchemy_continuum/plugins/activity.py @@ -3,7 +3,7 @@ individual entities. If you use ActivityPlugin you probably don't need to use TransactionChanges nor TransactionMeta plugins. -You can initalize the ActivityPlugin by adding it to versioning manager. +You can initialize the ActivityPlugin by adding it to versioning manager. :: @@ -47,7 +47,8 @@ transactions with the target and object. This allows each activity to also have object_version and target_version relationships for introspecting what those objects and targets were in given point in time. All these relationship -properties use `generic relationships`_ of the SQLAlchemy-Utils package. +properties use `generic relationships`_ ported from the SQLAlchemy-Utils +package. Limitations ^^^^^^^^^^^ @@ -184,21 +185,54 @@ .. _activity stream specification: - http://www.activitystrea.ms + https://www.activitystrea.ms .. _generic relationships: https://sqlalchemy-utils.readthedocs.io/en/latest/generic_relationship.html """ +from collections.abc import Iterable +import json import sqlalchemy as sa +from sqlalchemy.dialects.postgresql.base import ischema_names from sqlalchemy.ext.hybrid import hybrid_property from sqlalchemy.inspection import inspect +from sqlalchemy.orm import ColumnProperty, class_mapper +from sqlalchemy.orm.attributes import ( + ScalarAttributeImpl, + register_attribute, +) +from sqlalchemy.orm.base import PASSIVE_OFF +from sqlalchemy.orm.interfaces import MapperProperty, PropComparator +from sqlalchemy.orm.session import _state_session +from sqlalchemy.util import set_creation_order from .base import Plugin from ..factory import ModelFactory -from ..sa_utils import JSONType, generic_relationship +from ..sa_utils import identity from ..utils import version_class, version_obj +try: + from sqlalchemy.dialects.postgresql import JSON + has_postgres_json = True +except ImportError: + class PostgresJSONType(sa.types.UserDefinedType): + """ + Text search vector type for postgresql. + """ + def get_col_spec(self): + return 'json' + + ischema_names['json'] = PostgresJSONType + has_postgres_json = False + +def _get_class_registry(class_): + try: + return class_.registry._class_registry + except AttributeError: # SQLAlchemy <1.4 + return class_._decl_class_registry + + class ActivityBase(object): id = sa.Column( sa.BigInteger, @@ -341,3 +375,261 @@ def before_flush(self, uow, session): def after_version_class_built(self, parent_cls, version_cls): pass + + +class GenericAttributeImpl(ScalarAttributeImpl): + def __init__(self, *args, **kwargs): + """ + The constructor of attributes.AttributeImpl changed in SQLAlchemy 2.0.22, + adding a 'default_function' required positional argument before 'dispatch'. + This adjustment ensures compatibility across versions by inserting None for + 'default_function' in versions >= 2.0.22. + + Arguments received: (class, key, dispatch) + Required by AttributeImpl: (class, key, default_function, dispatch) + Setting None as default_function here. + """ + # Adjust for SQLAlchemy version change + sqlalchemy_version = tuple(map(int, sa.__version__.split('.'))) + if sqlalchemy_version >= (2, 0, 22): + args = (*args[:2], None, *args[2:]) + + super().__init__(*args, **kwargs) + + def get(self, state, dict_, passive=PASSIVE_OFF): + if self.key in dict_: + return dict_[self.key] + + # Retrieve the session bound to the state in order to perform + # a lazy query for the attribute. + # TODO: replace this with sa.orm.session.object_session? + session = _state_session(state) + if session is None: + # State is not bound to a session; we cannot proceed. + return None + + # Find class for discriminator. + # TODO: Perhaps optimize with some sort of lookup? + discriminator = self.get_state_discriminator(state) + target_class = _get_class_registry(state.class_).get(discriminator) + + if target_class is None: + # Unknown discriminator; return nothing. + return None + + id = self.get_state_id(state) + + try: + target = session.get(target_class, id) + except AttributeError: + # sqlalchemy 1.3 + target = session.query(target_class).get(id) + + # Return found (or not found) target. + return target + + def get_state_discriminator(self, state): + discriminator = self.parent_token.discriminator + if isinstance(discriminator, hybrid_property): + return getattr(state.obj(), discriminator.__name__) + else: + return state.attrs[discriminator.key].value + + def get_state_id(self, state): + # Lookup row with the discriminator and id. + return tuple(state.attrs[id.key].value for id in self.parent_token.id) + + def set(self, state, dict_, initiator, + passive=PASSIVE_OFF, + check_old=None, + pop=False): + + # Set us on the state. + dict_[self.key] = initiator + + if initiator is None: + # Nullify relationship args + for id in self.parent_token.id: + dict_[id.key] = None + dict_[self.parent_token.discriminator.key] = None + else: + # Get the primary key of the initiator and ensure we + # can support this assignment. + class_ = type(initiator) + mapper = class_mapper(class_) + + pk = mapper.identity_key_from_instance(initiator)[1] + + # Set the identifier and the discriminator. + discriminator = class_.__name__ + + for index, id in enumerate(self.parent_token.id): + dict_[id.key] = pk[index] + dict_[self.parent_token.discriminator.key] = discriminator + + +class GenericRelationshipProperty(MapperProperty): + """A generic form of the relationship property. + + Creates a 1 to many relationship between the parent model + and any other models using a discriminator (the table name). + + :param discriminator + Field to discriminate which model we are referring to. + :param id: + Field to point to the model we are referring to. + """ + + def __init__(self, discriminator, id, doc=None): + super().__init__() + self._discriminator_col = discriminator + self._id_cols = id + self._id = None + self._discriminator = None + self.doc = doc + + set_creation_order(self) + + def _column_to_property(self, column): + if isinstance(column, hybrid_property): + attr_key = column.__name__ + for key, attr in self.parent.all_orm_descriptors.items(): + if key == attr_key: + return attr + else: + for attr in self.parent.attrs.values(): + if isinstance(attr, ColumnProperty): + if attr.columns[0].name == column.name: + return attr + + def init(self): + def convert_strings(column): + if isinstance(column, str): + return self.parent.columns[column] + return column + + self._discriminator_col = convert_strings(self._discriminator_col) + self._id_cols = convert_strings(self._id_cols) + + if isinstance(self._id_cols, Iterable): + self._id_cols = list(map(convert_strings, self._id_cols)) + else: + self._id_cols = [self._id_cols] + + self.discriminator = self._column_to_property(self._discriminator_col) + + if self.discriminator is None: + raise ImproperlyConfigured( + 'Could not find discriminator descriptor.' + ) + + self.id = list(map(self._column_to_property, self._id_cols)) + + class Comparator(PropComparator): + def __init__(self, prop, parentmapper): + self.property = prop + self._parententity = parentmapper + + def __eq__(self, other): + discriminator = type(other).__name__ + q = self.property._discriminator_col == discriminator + other_id = identity(other) + for index, id in enumerate(self.property._id_cols): + q &= id == other_id[index] + return q + + def __ne__(self, other): + return ~(self == other) + + def is_type(self, other): + mapper = sa.inspect(other) + # Iterate through the weak sequence in order to get the actual + # mappers + class_names = [other.__name__] + class_names.extend([ + submapper.class_.__name__ + for submapper in mapper._inheriting_mappers + ]) + + return self.property._discriminator_col.in_(class_names) + + def instrument_class(self, mapper): + register_attribute( + mapper.class_, + self.key, + comparator=self.Comparator(self, mapper), + parententity=mapper, + doc=self.doc, + impl_class=GenericAttributeImpl, + parent_token=self + ) + + +class ImproperlyConfigured(Exception): + """ + SQLAlchemy-Continuum is improperly configured; normally due to usage of + a utility that depends on a missing library. + """ + + +class JSONType(sa.types.TypeDecorator): + """ + JSONType offers way of saving JSON data structures to database. On + PostgreSQL the underlying implementation of this data type is 'json' while + on other databases its simply 'text'. + + :: + + + from sqlalchemy_continuum.plugins.activity import JSONType + + + class Product(Base): + __tablename__ = 'product' + id = sa.Column(sa.Integer, autoincrement=True) + name = sa.Column(sa.Unicode(50)) + details = sa.Column(JSONType) + + + product = Product() + product.details = { + 'color': 'red', + 'type': 'car', + 'max-speed': '400 mph' + } + session.commit() + """ + impl = sa.UnicodeText + hashable = False + cache_ok = True + + def __init__(self, *args, **kwargs): + super(JSONType, self).__init__(*args, **kwargs) + + def load_dialect_impl(self, dialect): + if dialect.name == 'postgresql': + # Use the native JSON type. + if has_postgres_json: + return dialect.type_descriptor(JSON()) + else: + return dialect.type_descriptor(PostgresJSONType()) + else: + return dialect.type_descriptor(self.impl) + + def process_bind_param(self, value, dialect): + if dialect.name == 'postgresql' and has_postgres_json: + return value + if value is not None: + value = json.dumps(value) + return value + + def process_result_value(self, value, dialect): + if dialect.name == 'postgresql': + return value + if value is not None: + value = json.loads(value) + return value + + +def generic_relationship(*args, **kwargs): + return GenericRelationshipProperty(*args, **kwargs) diff --git a/sqlalchemy_continuum/plugins/flask.py b/sqlalchemy_continuum/plugins/flask.py index cc3ed82a..abcf0f92 100644 --- a/sqlalchemy_continuum/plugins/flask.py +++ b/sqlalchemy_continuum/plugins/flask.py @@ -26,7 +26,6 @@ pass from .base import Plugin -from ..sa_utils import ImproperlyConfigured def fetch_current_user_id(): @@ -60,7 +59,7 @@ def __init__( ) if not flask: - raise ImproperlyConfigured( + raise ImportError( 'Flask is required with FlaskPlugin. Please install Flask by' ' running pip install Flask' ) diff --git a/sqlalchemy_continuum/sa_utils.py b/sqlalchemy_continuum/sa_utils.py index 815c8e27..38f9a403 100644 --- a/sqlalchemy_continuum/sa_utils.py +++ b/sqlalchemy_continuum/sa_utils.py @@ -1,41 +1,8 @@ -from collections.abc import Iterable from inspect import isclass -import json import sqlalchemy as sa -from sqlalchemy.dialects.postgresql.base import ischema_names -from sqlalchemy.ext.hybrid import hybrid_property -from sqlalchemy.orm import ColumnProperty, class_mapper -from sqlalchemy.orm.attributes import ( - InstrumentedAttribute, - ScalarAttributeImpl, - register_attribute, -) -from sqlalchemy.orm.base import PASSIVE_OFF -from sqlalchemy.orm.interfaces import MapperProperty, PropComparator -from sqlalchemy.orm.session import _state_session -from sqlalchemy.util import set_creation_order - -try: - from sqlalchemy.dialects.postgresql import JSON - has_postgres_json = True -except ImportError: - class PostgresJSONType(sa.types.UserDefinedType): - """ - Text search vector type for postgresql. - """ - def get_col_spec(self): - return 'json' - - ischema_names['json'] = PostgresJSONType - has_postgres_json = False - - -def _get_class_registry(class_): - try: - return class_.registry._class_registry - except AttributeError: # SQLAlchemy <1.4 - return class_._decl_class_registry +from sqlalchemy.orm import ColumnProperty +from sqlalchemy.orm.attributes import InstrumentedAttribute def _get_columns(mixed): @@ -85,264 +52,6 @@ def _get_columns(mixed): return sa.inspect(mixed).columns -class GenericAttributeImpl(ScalarAttributeImpl): - def __init__(self, *args, **kwargs): - """ - The constructor of attributes.AttributeImpl changed in SQLAlchemy 2.0.22, - adding a 'default_function' required positional argument before 'dispatch'. - This adjustment ensures compatibility across versions by inserting None for - 'default_function' in versions >= 2.0.22. - - Arguments received: (class, key, dispatch) - Required by AttributeImpl: (class, key, default_function, dispatch) - Setting None as default_function here. - """ - # Adjust for SQLAlchemy version change - sqlalchemy_version = tuple(map(int, sa.__version__.split('.'))) - if sqlalchemy_version >= (2, 0, 22): - args = (*args[:2], None, *args[2:]) - - super().__init__(*args, **kwargs) - - def get(self, state, dict_, passive=PASSIVE_OFF): - if self.key in dict_: - return dict_[self.key] - - # Retrieve the session bound to the state in order to perform - # a lazy query for the attribute. - # TODO: replace this with sa.orm.session.object_session? - session = _state_session(state) - if session is None: - # State is not bound to a session; we cannot proceed. - return None - - # Find class for discriminator. - # TODO: Perhaps optimize with some sort of lookup? - discriminator = self.get_state_discriminator(state) - target_class = _get_class_registry(state.class_).get(discriminator) - - if target_class is None: - # Unknown discriminator; return nothing. - return None - - id = self.get_state_id(state) - - try: - target = session.get(target_class, id) - except AttributeError: - # sqlalchemy 1.3 - target = session.query(target_class).get(id) - - # Return found (or not found) target. - return target - - def get_state_discriminator(self, state): - discriminator = self.parent_token.discriminator - if isinstance(discriminator, hybrid_property): - return getattr(state.obj(), discriminator.__name__) - else: - return state.attrs[discriminator.key].value - - def get_state_id(self, state): - # Lookup row with the discriminator and id. - return tuple(state.attrs[id.key].value for id in self.parent_token.id) - - def set(self, state, dict_, initiator, - passive=PASSIVE_OFF, - check_old=None, - pop=False): - - # Set us on the state. - dict_[self.key] = initiator - - if initiator is None: - # Nullify relationship args - for id in self.parent_token.id: - dict_[id.key] = None - dict_[self.parent_token.discriminator.key] = None - else: - # Get the primary key of the initiator and ensure we - # can support this assignment. - class_ = type(initiator) - mapper = class_mapper(class_) - - pk = mapper.identity_key_from_instance(initiator)[1] - - # Set the identifier and the discriminator. - discriminator = class_.__name__ - - for index, id in enumerate(self.parent_token.id): - dict_[id.key] = pk[index] - dict_[self.parent_token.discriminator.key] = discriminator - - -class GenericRelationshipProperty(MapperProperty): - """A generic form of the relationship property. - - Creates a 1 to many relationship between the parent model - and any other models using a discriminator (the table name). - - :param discriminator - Field to discriminate which model we are referring to. - :param id: - Field to point to the model we are referring to. - """ - - def __init__(self, discriminator, id, doc=None): - super().__init__() - self._discriminator_col = discriminator - self._id_cols = id - self._id = None - self._discriminator = None - self.doc = doc - - set_creation_order(self) - - def _column_to_property(self, column): - if isinstance(column, hybrid_property): - attr_key = column.__name__ - for key, attr in self.parent.all_orm_descriptors.items(): - if key == attr_key: - return attr - else: - for attr in self.parent.attrs.values(): - if isinstance(attr, ColumnProperty): - if attr.columns[0].name == column.name: - return attr - - def init(self): - def convert_strings(column): - if isinstance(column, str): - return self.parent.columns[column] - return column - - self._discriminator_col = convert_strings(self._discriminator_col) - self._id_cols = convert_strings(self._id_cols) - - if isinstance(self._id_cols, Iterable): - self._id_cols = list(map(convert_strings, self._id_cols)) - else: - self._id_cols = [self._id_cols] - - self.discriminator = self._column_to_property(self._discriminator_col) - - if self.discriminator is None: - raise ImproperlyConfigured( - 'Could not find discriminator descriptor.' - ) - - self.id = list(map(self._column_to_property, self._id_cols)) - - class Comparator(PropComparator): - def __init__(self, prop, parentmapper): - self.property = prop - self._parententity = parentmapper - - def __eq__(self, other): - discriminator = type(other).__name__ - q = self.property._discriminator_col == discriminator - other_id = identity(other) - for index, id in enumerate(self.property._id_cols): - q &= id == other_id[index] - return q - - def __ne__(self, other): - return ~(self == other) - - def is_type(self, other): - mapper = sa.inspect(other) - # Iterate through the weak sequence in order to get the actual - # mappers - class_names = [other.__name__] - class_names.extend([ - submapper.class_.__name__ - for submapper in mapper._inheriting_mappers - ]) - - return self.property._discriminator_col.in_(class_names) - - def instrument_class(self, mapper): - register_attribute( - mapper.class_, - self.key, - comparator=self.Comparator(self, mapper), - parententity=mapper, - doc=self.doc, - impl_class=GenericAttributeImpl, - parent_token=self - ) - - -class ImproperlyConfigured(Exception): - """ - SQLAlchemy-Continuum is improperly configured; normally due to usage of - a utility that depends on a missing library. - """ - - -class JSONType(sa.types.TypeDecorator): - """ - JSONType offers way of saving JSON data structures to database. On - PostgreSQL the underlying implementation of this data type is 'json' while - on other databases its simply 'text'. - - :: - - - from sqlalchemy_continuum.sa_utils import JSONType - - - class Product(Base): - __tablename__ = 'product' - id = sa.Column(sa.Integer, autoincrement=True) - name = sa.Column(sa.Unicode(50)) - details = sa.Column(JSONType) - - - product = Product() - product.details = { - 'color': 'red', - 'type': 'car', - 'max-speed': '400 mph' - } - session.commit() - """ - impl = sa.UnicodeText - hashable = False - cache_ok = True - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def load_dialect_impl(self, dialect): - if dialect.name == 'postgresql': - # Use the native JSON type. - if has_postgres_json: - return dialect.type_descriptor(JSON()) - else: - return dialect.type_descriptor(PostgresJSONType()) - else: - return dialect.type_descriptor(self.impl) - - def process_bind_param(self, value, dialect): - if dialect.name == 'postgresql' and has_postgres_json: - return value - if value is not None: - value = json.dumps(value) - return value - - def process_result_value(self, value, dialect): - if dialect.name == 'postgresql': - return value - if value is not None: - value = json.loads(value) - return value - - -def generic_relationship(*args, **kwargs): - return GenericRelationshipProperty(*args, **kwargs) - - def get_column_key(model, column): """ Return the key for given column in given model. From 3e9038d4e97bfbfa8318fefb08df152b4c5b8426 Mon Sep 17 00:00:00 2001 From: Matt F Date: Wed, 7 Aug 2024 17:41:15 -0400 Subject: [PATCH 5/7] Makes get_column_key private to manager.py The function only gets used in one place, so we can move it there and mark it private. --- sqlalchemy_continuum/manager.py | 38 ++++++++++++++++++++++++++++++-- sqlalchemy_continuum/sa_utils.py | 35 ----------------------------- 2 files changed, 36 insertions(+), 37 deletions(-) diff --git a/sqlalchemy_continuum/manager.py b/sqlalchemy_continuum/manager.py index 451d15e2..2fc43ce4 100644 --- a/sqlalchemy_continuum/manager.py +++ b/sqlalchemy_continuum/manager.py @@ -7,12 +7,46 @@ from .fetcher import SubqueryFetcher, ValidityFetcher from .operation import Operation from .plugins import PluginCollection -from .sa_utils import get_column_key from .transaction import TransactionFactory from .unit_of_work import UnitOfWork from .utils import is_modified, is_versioned, version_table +def _get_column_key(model, column): + """ + Return the key for given column in given model. + + :param model: SQLAlchemy declarative model object + + :: + + class User(Base): + __tablename__ = 'user' + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column('_name', sa.String) + + + get_column_key(User, User.__table__.c._name) # 'name' + + .. versionadded: 0.26.5 + + .. versionchanged: 0.27.11 + Throws UnmappedColumnError instead of ValueError when no property was + found for given column. This is consistent with how SQLAlchemy works. + """ + mapper = sa.inspect(model) + try: + return mapper.get_property_by_column(column).key + except sa.orm.exc.UnmappedColumnError: + for key, c in mapper.columns.items(): + if c.name == column.name and c.table is column.table: + return key + raise sa.orm.exc.UnmappedColumnError( + 'No column %s is configured on mapper %s...' % + (column, mapper) + ) + + def tracked_operation(func): @wraps(func) def wrapper(self, mapper, connection, target): @@ -155,7 +189,7 @@ def create_transaction_model(self): def is_excluded_column(self, model, column): try: - key = get_column_key(model, column) + key = _get_column_key(model, column) except sa.orm.exc.UnmappedColumnError: return False diff --git a/sqlalchemy_continuum/sa_utils.py b/sqlalchemy_continuum/sa_utils.py index 38f9a403..ae40f70b 100644 --- a/sqlalchemy_continuum/sa_utils.py +++ b/sqlalchemy_continuum/sa_utils.py @@ -52,41 +52,6 @@ def _get_columns(mixed): return sa.inspect(mixed).columns -def get_column_key(model, column): - """ - Return the key for given column in given model. - - :param model: SQLAlchemy declarative model object - - :: - - class User(Base): - __tablename__ = 'user' - id = sa.Column(sa.Integer, primary_key=True) - name = sa.Column('_name', sa.String) - - - get_column_key(User, User.__table__.c._name) # 'name' - - .. versionadded: 0.26.5 - - .. versionchanged: 0.27.11 - Throws UnmappedColumnError instead of ValueError when no property was - found for given column. This is consistent with how SQLAlchemy works. - """ - mapper = sa.inspect(model) - try: - return mapper.get_property_by_column(column).key - except sa.orm.exc.UnmappedColumnError: - for key, c in mapper.columns.items(): - if c.name == column.name and c.table is column.table: - return key - raise sa.orm.exc.UnmappedColumnError( - 'No column %s is configured on mapper %s...' % - (column, mapper) - ) - - def get_declarative_base(model): """ Returns the declarative base for given model class. From d1452fd1c628b932632f7b5fd5347f4ccd9beeab Mon Sep 17 00:00:00 2001 From: Matt F Date: Wed, 7 Aug 2024 17:41:15 -0400 Subject: [PATCH 6/7] Makes naturally_equivalent private to utils.py The function only gets used in one place, so we can move it there and mark it private. --- sqlalchemy_continuum/sa_utils.py | 31 ------------------------------- sqlalchemy_continuum/utils.py | 31 +++++++++++++++++++++++++++++-- 2 files changed, 29 insertions(+), 33 deletions(-) diff --git a/sqlalchemy_continuum/sa_utils.py b/sqlalchemy_continuum/sa_utils.py index ae40f70b..c66fd115 100644 --- a/sqlalchemy_continuum/sa_utils.py +++ b/sqlalchemy_continuum/sa_utils.py @@ -209,34 +209,3 @@ def identity(obj_or_class): getattr(obj_or_class, column_key) for column_key in get_primary_key_columns(obj_or_class) ) - - -def naturally_equivalent(obj, obj2): - """ - Returns whether two given SQLAlchemy declarative instances are - naturally equivalent (all their non-primary key properties are equivalent). - - - :: - - from sqlalchemy_continuum.sa_utils import naturally_equivalent - - - user = User(name='someone') - user2 = User(name='someone') - - user == user2 # False - - naturally_equivalent(user, user2) # True - - - :param obj: SQLAlchemy declarative model object - :param obj2: SQLAlchemy declarative model object to compare with `obj` - """ - for column_key, column in sa.inspect(obj.__class__).columns.items(): - if column.primary_key: - continue - - if not (getattr(obj, column_key) == getattr(obj2, column_key)): - return False - return True diff --git a/sqlalchemy_continuum/utils.py b/sqlalchemy_continuum/utils.py index 359ea734..ed479515 100644 --- a/sqlalchemy_continuum/utils.py +++ b/sqlalchemy_continuum/utils.py @@ -7,7 +7,34 @@ from sqlalchemy.orm.util import AliasedClass from .exc import ClassNotVersioned -from .sa_utils import get_primary_key_columns, identity, naturally_equivalent +from .sa_utils import get_primary_key_columns, identity + + +def _naturally_equivalent(obj, obj2): + """ + Returns whether two given SQLAlchemy declarative instances are + naturally equivalent (all their non-primary key properties are equivalent). + + + :: + user = User(name='someone') + user2 = User(name='someone') + + user == user2 # False + + _naturally_equivalent(user, user2) # True + + + :param obj: SQLAlchemy declarative model object + :param obj2: SQLAlchemy declarative model object to compare with `obj` + """ + for column_key, column in sa.inspect(obj.__class__).columns.items(): + if column.primary_key: + continue + + if not (getattr(obj, column_key) == getattr(obj2, column_key)): + return False + return True def get_versioning_manager(obj_or_class): @@ -256,7 +283,7 @@ def vacuum(session, model, yield_per=1000): version_id = getattr(version, primary_key_col) if versions[version_id]: prev_version = versions[version_id][-1] - if naturally_equivalent(prev_version, version): + if _naturally_equivalent(prev_version, version): session.delete(version) else: versions[version_id].append(version) From 475d76856ec5233e1de38e76158bc72fccd7c44d Mon Sep 17 00:00:00 2001 From: Matt F Date: Thu, 8 Aug 2024 13:29:14 -0400 Subject: [PATCH 7/7] Switches to a relative import Using a relative import allows the library to be more-easily vendored without having to tweak the import line. --- sqlalchemy_continuum/dialects/postgresql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sqlalchemy_continuum/dialects/postgresql.py b/sqlalchemy_continuum/dialects/postgresql.py index 8681fe35..87d3f6b4 100644 --- a/sqlalchemy_continuum/dialects/postgresql.py +++ b/sqlalchemy_continuum/dialects/postgresql.py @@ -1,6 +1,6 @@ import sqlalchemy as sa -from sqlalchemy_continuum.plugins import PropertyModTrackerPlugin +from ..plugins.property_mod_tracker import PropertyModTrackerPlugin trigger_sql = """