From 721f0e4c058260709d7fc893db584f1509486516 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 8 May 2024 18:58:01 +0200 Subject: [PATCH 1/4] add (async) tests that raise exceptions in prepare and conditions during may --- tests/test_async.py | 31 ++++++++++++++++++++++++++++++- tests/test_core.py | 25 +++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/tests/test_async.py b/tests/test_async.py index 4166475b..8df58726 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -12,7 +12,7 @@ from functools import partial import weakref from .test_core import TestTransitions, MachineError, TYPE_CHECKING -from .utils import DummyModel +from .utils import DummyModel, Stuff from .test_graphviz import pgv as gv from .test_pygraphviz import pgv @@ -513,6 +513,35 @@ async def run(): asyncio.run(run()) + def test_may_transition_with_exception(self): + + stuff = Stuff(machine_cls=self.machine_cls, extra_kwargs={"send_event": True}) + stuff.machine.add_transition(trigger="raises", source="A", dest="B", prepare=partial(stuff.this_raises, RuntimeError("Prepare Exception"))) + stuff.machine.add_transition(trigger="raises", source="B", dest="C", conditions=partial(stuff.this_raises, ValueError("Condition Exception"))) + stuff.machine.add_transition(trigger="works", source="A", dest="B") + + def process_exception(event_data): + assert event_data.error is not None + assert event_data.transition is not None + assert event_data.event.name == "raises" + assert event_data.machine == stuff.machine + + async def run(): + with self.assertRaises(RuntimeError): + await stuff.may_raises() + assert stuff.is_A() + assert await stuff.may_works() + assert await stuff.works() + with self.assertRaises(ValueError): + await stuff.may_raises() + assert stuff.is_B() + stuff.machine.on_exception.append(process_exception) + assert not await stuff.may_raises() + assert await stuff.to_A() + assert not await stuff.may_raises() + + asyncio.run(run()) + @skipIf(asyncio is None or (pgv is None and gv is None), "AsyncGraphMachine requires asyncio and (py)gaphviz") class TestAsyncGraphMachine(TestAsync): diff --git a/tests/test_core.py b/tests/test_core.py index 605b5142..cfcde6c2 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1294,3 +1294,28 @@ def test_may_transition_with_invalid_state(self): m = self.machine_cls(model=d, states=states, initial='A', auto_transitions=False) m.add_transition('walk', 'A', 'UNKNOWN') assert not d.may_walk() + + def test_may_transition_with_exception(self): + stuff = Stuff(machine_cls=self.machine_cls, extra_kwargs={"send_event": True}) + stuff.machine.add_transition(trigger="raises", source="A", dest="B", prepare=partial(stuff.this_raises, RuntimeError("Prepare Exception"))) + stuff.machine.add_transition(trigger="raises", source="B", dest="C", conditions=partial(stuff.this_raises, ValueError("Condition Exception"))) + stuff.machine.add_transition(trigger="works", source="A", dest="B") + + def process_exception(event_data): + assert event_data.error is not None + assert event_data.transition is not None + assert event_data.event.name == "raises" + assert event_data.machine == stuff.machine + + with self.assertRaises(RuntimeError): + stuff.may_raises() + assert stuff.is_A() + assert stuff.may_works() + assert stuff.works() + with self.assertRaises(ValueError): + stuff.may_raises() + assert stuff.is_B() + stuff.machine.on_exception.append(process_exception) + assert not stuff.may_raises() + assert stuff.to_A() + assert not stuff.may_raises() From 675cc25ea8065af754126d5f5582d64e4e5aec69 Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 8 May 2024 18:58:49 +0200 Subject: [PATCH 2/4] this fixes #626 --- transitions/core.py | 23 +++++++++++------ transitions/extensions/asyncio.py | 41 ++++++++++++++++++++++--------- transitions/extensions/nesting.py | 19 ++++++++++---- 3 files changed, 58 insertions(+), 25 deletions(-) diff --git a/transitions/core.py b/transitions/core.py index a1968526..429c2ad9 100644 --- a/transitions/core.py +++ b/transitions/core.py @@ -876,23 +876,30 @@ def _checked_assignment(self, model, name, func): setattr(model, name, func) def _can_trigger(self, model, trigger, *args, **kwargs): - evt = EventData(None, None, self, model, args, kwargs) - state = self.get_model_state(model).name + state = self.get_model_state(model) + event_data = EventData(state, Event(name=trigger, machine=self), self, model, args, kwargs) for trigger_name in self.get_triggers(state): if trigger_name != trigger: continue - for transition in self.events[trigger_name].transitions[state]: + for transition in self.events[trigger_name].transitions[state.name]: try: _ = self.get_state(transition.dest) if transition.dest is not None else transition.source except ValueError: continue - evt.transition = transition - self.callbacks(self.prepare_event, evt) - self.callbacks(transition.prepare, evt) - if all(c.check(evt) for c in transition.conditions): - return True + event_data.transition = transition + try: + self.callbacks(self.prepare_event, event_data) + self.callbacks(transition.prepare, event_data) + if all(c.check(event_data) for c in transition.conditions): + return True + except BaseException as err: + event_data.error = err + if self.on_exception: + self.callbacks(self.on_exception, event_data) + else: + raise return False def _add_may_transition_func_for_trigger(self, trigger, model): diff --git a/transitions/extensions/asyncio.py b/transitions/extensions/asyncio.py index 7f46db28..c82674f6 100644 --- a/transitions/extensions/asyncio.py +++ b/transitions/extensions/asyncio.py @@ -425,21 +425,29 @@ def remove_model(self, model): self._transition_queue.extend(new_queue) async def _can_trigger(self, model, trigger, *args, **kwargs): - evt = AsyncEventData(None, None, self, model, args, kwargs) - state = self.get_model_state(model).name + state = self.get_model_state(model) + event_data = AsyncEventData(state, AsyncEvent(name=trigger, machine=self), self, model, args, kwargs) for trigger_name in self.get_triggers(state): if trigger_name != trigger: continue - for transition in self.events[trigger_name].transitions[state]: + for transition in self.events[trigger_name].transitions[state.name]: try: _ = self.get_state(transition.dest) if transition.dest is not None else transition.source except ValueError: continue - await self.callbacks(self.prepare_event, evt) - await self.callbacks(transition.prepare, evt) - if all(await self.await_all([partial(c.check, evt) for c in transition.conditions])): - return True + event_data.transition = transition + try: + await self.callbacks(self.prepare_event, event_data) + await self.callbacks(transition.prepare, event_data) + if all(await self.await_all([partial(c.check, event_data) for c in transition.conditions])): + return True + except BaseException as err: + event_data.error = err + if self.on_exception: + await self.callbacks(self.on_exception, event_data) + else: + raise return False def _process(self, trigger): @@ -552,20 +560,29 @@ async def _can_trigger(self, model, trigger, *args, **kwargs): return await self._can_trigger_nested(model, trigger, state_path, *args, **kwargs) async def _can_trigger_nested(self, model, trigger, path, *args, **kwargs): - evt = AsyncEventData(None, None, self, model, args, kwargs) if trigger in self.events: source_path = copy.copy(path) while source_path: + event_data = AsyncEventData(self.get_state(source_path), AsyncEvent(name=trigger, machine=self), self, + model, args, kwargs) state_name = self.state_cls.separator.join(source_path) for transition in self.events[trigger].transitions.get(state_name, []): try: _ = self.get_state(transition.dest) if transition.dest is not None else transition.source except ValueError: continue - await self.callbacks(self.prepare_event, evt) - await self.callbacks(transition.prepare, evt) - if all(await self.await_all([partial(c.check, evt) for c in transition.conditions])): - return True + event_data.transition = transition + try: + await self.callbacks(self.prepare_event, event_data) + await self.callbacks(transition.prepare, event_data) + if all(await self.await_all([partial(c.check, event_data) for c in transition.conditions])): + return True + except BaseException as err: + event_data.error = err + if self.on_exception: + await self.callbacks(self.on_exception, event_data) + else: + raise source_path.pop(-1) if path: with self(path.pop(0)): diff --git a/transitions/extensions/nesting.py b/transitions/extensions/nesting.py index 3935bfae..5d2832fd 100644 --- a/transitions/extensions/nesting.py +++ b/transitions/extensions/nesting.py @@ -679,20 +679,29 @@ def _can_trigger(self, model, trigger, *args, **kwargs): return self._can_trigger_nested(model, trigger, state_path, *args, **kwargs) def _can_trigger_nested(self, model, trigger, path, *args, **kwargs): - evt = NestedEventData(None, None, self, model, args, kwargs) if trigger in self.events: source_path = copy.copy(path) while source_path: + event_data = EventData(self.get_state(source_path), Event(name=trigger, machine=self), self, model, + args, kwargs) state_name = self.state_cls.separator.join(source_path) for transition in self.events[trigger].transitions.get(state_name, []): try: _ = self.get_state(transition.dest) if transition.dest is not None else transition.source except ValueError: continue - self.callbacks(self.prepare_event, evt) - self.callbacks(transition.prepare, evt) - if all(c.check(evt) for c in transition.conditions): - return True + event_data.transition = transition + try: + self.callbacks(self.prepare_event, event_data) + self.callbacks(transition.prepare, event_data) + if all(c.check(event_data) for c in transition.conditions): + return True + except BaseException as err: + event_data.error = err + if self.on_exception: + self.callbacks(self.on_exception, event_data) + else: + raise source_path.pop(-1) if path: with self(path.pop(0)): From e92fcce6f646fa2b52d60b2fcbeedc99647760fc Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 8 May 2024 19:00:23 +0200 Subject: [PATCH 3/4] typing: state and event mandatory in EventData --- transitions/core.pyi | 4 ++-- transitions/extensions/nesting.pyi | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transitions/core.pyi b/transitions/core.pyi index 447d48b1..60d85a41 100644 --- a/transitions/core.pyi +++ b/transitions/core.pyi @@ -68,8 +68,8 @@ class Transition: TransitionConfig = Union[Sequence[Union[str, Any]], Dict[str, Any], Transition] class EventData: - state: Optional[State] - event: Optional[Event] + state: State + event: Event machine: Machine model: object args: Iterable[Any] diff --git a/transitions/extensions/nesting.pyi b/transitions/extensions/nesting.pyi index 64f9f6aa..e2c91044 100644 --- a/transitions/extensions/nesting.pyi +++ b/transitions/extensions/nesting.pyi @@ -21,8 +21,8 @@ class NestedEvent(Event): class NestedEventData(EventData): - state: Optional[NestedState] - event: Optional[NestedEvent] + state: NestedState + event: NestedEvent machine: HierarchicalMachine transition: Optional[NestedTransition] source_name: Optional[str] From de04fd9744fa7e5c1fd23db9ff5e33e13412c28c Mon Sep 17 00:00:00 2001 From: Alexander Neumann Date: Wed, 8 May 2024 19:00:28 +0200 Subject: [PATCH 4/4] update changelog --- Changelog.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Changelog.md b/Changelog.md index 97624169..cd0bf162 100644 --- a/Changelog.md +++ b/Changelog.md @@ -9,8 +9,9 @@ - '_anchor' suffix has been removed for (py)graphviz cluster node anchors - local testing switched from [tox](https://github.com/tox-dev/tox) to [nox](https://github.com/wntrblm/nox) - PR #633: Remove surrounding whitespace from docstrings (thanks @artofhuman) +- Bug #626: Process exceptions with `Machine.on_exception` in may_ as well (thanks @match1) - Typing: - + Made machine property mandatory in (Nested)EventData + + Made state, event and machine property mandatory in (Nested)EventData + Transition.dest may be None ## 0.9.0 (September 2022)