Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

correctly handle signals in nested transactions #653

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions flask_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,14 @@ class _SessionSignalEvents(object):
@classmethod
def register(cls, session):
if not hasattr(session, '_model_changes'):
session._model_changes = {}
session._model_changes = []

event.listen(session, 'before_flush', cls.record_ops)
event.listen(session, 'before_commit', cls.record_ops)
event.listen(session, 'before_commit', cls.before_commit)
event.listen(session, 'after_commit', cls.after_commit)
event.listen(session, 'after_rollback', cls.after_rollback)
event.listen(session, 'after_transaction_create', cls.after_transaction_create)

@classmethod
def unregister(cls, session):
Expand All @@ -184,6 +185,7 @@ def unregister(cls, session):
event.remove(session, 'before_commit', cls.before_commit)
event.remove(session, 'after_commit', cls.after_commit)
event.remove(session, 'after_rollback', cls.after_rollback)
event.remove(session, 'after_transaction_create', cls.after_transaction_create)

@staticmethod
def record_ops(session, flush_context=None, instances=None):
Expand All @@ -196,28 +198,54 @@ def record_ops(session, flush_context=None, instances=None):
for target in targets:
state = inspect(target)
key = state.identity_key if state.has_identity else id(target)
d[key] = (target, operation)
d[-1][key] = (target, operation)

@staticmethod
def after_transaction_create(session, transaction):
if transaction.parent and not transaction.nested:
return

try:
d = session._model_changes
except AttributeError:
return

d.append({})

@staticmethod
def before_commit(session):
if session.transaction.nested:
return

try:
d = session._model_changes
except AttributeError:
return

if d:
before_models_committed.send(session.app, changes=list(d.values()))
for level in d[1:]:
d[0].update(level)

if d[0]:
before_models_committed.send(session.app, changes=list(d[0].values()))

@staticmethod
def after_commit(session):
if session.transaction.nested:
return

try:
d = session._model_changes
except AttributeError:
return

if d:
models_committed.send(session.app, changes=list(d.values()))
d.clear()
for level in d[1:]:
d[0].update(level)

if d[0]:
models_committed.send(session.app, changes=list(d[0].values()))
del d[:]

@staticmethod
def after_rollback(session):
Expand All @@ -226,7 +254,10 @@ def after_rollback(session):
except AttributeError:
return

d.clear()
try:
del d[-1]
except IndexError:
pass


class _EngineDebuggingSignalEvents(object):
Expand Down
67 changes: 67 additions & 0 deletions tests/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pytest

import flask_sqlalchemy as fsa
import sqlalchemy as sa


pytestmark = pytest.mark.skipif(
Expand All @@ -16,6 +17,24 @@ def app(app):
return app


@pytest.fixture()
def db(db):
# required for correct handling of nested transactions, see
# https://docs.sqlalchemy.org/en/rel_1_0/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
@sa.event.listens_for(db.engine, "connect")
def do_connect(dbapi_connection, connection_record):
# disable pysqlite's emitting of the BEGIN statement entirely.
# also stops it from emitting COMMIT before any DDL.
dbapi_connection.isolation_level = None

@sa.event.listens_for(db.engine, "begin")
def do_begin(conn):
# emit our own BEGIN
conn.execute("BEGIN")

return db


def test_before_committed(app, db, Todo):
class Namespace(object):
is_received = False
Expand Down Expand Up @@ -59,3 +78,51 @@ def committed(sender, changes):
assert recorded[0][0] == todo
assert recorded[0][1] == 'delete'
fsa.models_committed.disconnect(committed)


def test_model_signals_nested_transaction(db, Todo):
before_commit_recorded = []
commit_recorded = []

def before_committed(sender, changes):
before_commit_recorded.extend(changes)

def committed(sender, changes):
commit_recorded.extend(changes)

fsa.before_models_committed.connect(before_committed)
fsa.models_committed.connect(committed)
with db.session.begin_nested():
todo = Todo('Awesome', 'the text')
db.session.add(todo)
try:
with db.session.begin_nested():
todo2 = Todo('Bad', 'to rollback')
db.session.add(todo2)
raise Exception('raising to roll back')
except Exception:
pass
assert before_commit_recorded == []
assert commit_recorded == []
db.session.commit()
assert before_commit_recorded == [(todo, 'insert')]
assert commit_recorded == [(todo, 'insert')]
del before_commit_recorded[:]
del commit_recorded[:]
try:
with db.session.begin_nested():
todo = Todo('Great', 'the text')
db.session.add(todo)
with db.session.begin_nested():
todo2 = Todo('Bad', 'to rollback')
db.session.add(todo2)
raise Exception('raising to roll back')
except Exception:
pass
assert before_commit_recorded == []
assert commit_recorded == []
db.session.commit()
assert before_commit_recorded == []
assert commit_recorded == []
fsa.before_models_committed.disconnect(before_committed)
fsa.models_committed.disconnect(committed)