Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing #626: Exceptions were raised with may_<trigger> even though on_exception is provided #663

Merged
merged 4 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
- '_anchor' suffix has been removed for (py)graphviz cluster node anchors
- local testing switched from [tox](https://github.com/tox-dev/tox) to [nox](https://github.com/wntrblm/nox)
- PR #633: Remove surrounding whitespace from docstrings (thanks @artofhuman)
- Bug #626: Process exceptions with `Machine.on_exception` in may_<trigger> as well (thanks @match1)
- Typing:
+ Made machine property mandatory in (Nested)EventData
+ Made state, event and machine property mandatory in (Nested)EventData
+ Transition.dest may be None

## 0.9.0 (September 2022)
Expand Down
31 changes: 30 additions & 1 deletion tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from functools import partial
import weakref
from .test_core import TestTransitions, MachineError, TYPE_CHECKING
from .utils import DummyModel
from .utils import DummyModel, Stuff
from .test_graphviz import pgv as gv
from .test_pygraphviz import pgv

Expand Down Expand Up @@ -513,6 +513,35 @@ async def run():

asyncio.run(run())

def test_may_transition_with_exception(self):

stuff = Stuff(machine_cls=self.machine_cls, extra_kwargs={"send_event": True})
stuff.machine.add_transition(trigger="raises", source="A", dest="B", prepare=partial(stuff.this_raises, RuntimeError("Prepare Exception")))
stuff.machine.add_transition(trigger="raises", source="B", dest="C", conditions=partial(stuff.this_raises, ValueError("Condition Exception")))
stuff.machine.add_transition(trigger="works", source="A", dest="B")

def process_exception(event_data):
assert event_data.error is not None
assert event_data.transition is not None
assert event_data.event.name == "raises"
assert event_data.machine == stuff.machine

async def run():
with self.assertRaises(RuntimeError):
await stuff.may_raises()
assert stuff.is_A()
assert await stuff.may_works()
assert await stuff.works()
with self.assertRaises(ValueError):
await stuff.may_raises()
assert stuff.is_B()
stuff.machine.on_exception.append(process_exception)
assert not await stuff.may_raises()
assert await stuff.to_A()
assert not await stuff.may_raises()

asyncio.run(run())


@skipIf(asyncio is None or (pgv is None and gv is None), "AsyncGraphMachine requires asyncio and (py)gaphviz")
class TestAsyncGraphMachine(TestAsync):
Expand Down
25 changes: 25 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,3 +1294,28 @@ def test_may_transition_with_invalid_state(self):
m = self.machine_cls(model=d, states=states, initial='A', auto_transitions=False)
m.add_transition('walk', 'A', 'UNKNOWN')
assert not d.may_walk()

def test_may_transition_with_exception(self):
stuff = Stuff(machine_cls=self.machine_cls, extra_kwargs={"send_event": True})
stuff.machine.add_transition(trigger="raises", source="A", dest="B", prepare=partial(stuff.this_raises, RuntimeError("Prepare Exception")))
stuff.machine.add_transition(trigger="raises", source="B", dest="C", conditions=partial(stuff.this_raises, ValueError("Condition Exception")))
stuff.machine.add_transition(trigger="works", source="A", dest="B")

def process_exception(event_data):
assert event_data.error is not None
assert event_data.transition is not None
assert event_data.event.name == "raises"
assert event_data.machine == stuff.machine

with self.assertRaises(RuntimeError):
stuff.may_raises()
assert stuff.is_A()
assert stuff.may_works()
assert stuff.works()
with self.assertRaises(ValueError):
stuff.may_raises()
assert stuff.is_B()
stuff.machine.on_exception.append(process_exception)
assert not stuff.may_raises()
assert stuff.to_A()
assert not stuff.may_raises()
23 changes: 15 additions & 8 deletions transitions/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -876,23 +876,30 @@ def _checked_assignment(self, model, name, func):
setattr(model, name, func)

def _can_trigger(self, model, trigger, *args, **kwargs):
evt = EventData(None, None, self, model, args, kwargs)
state = self.get_model_state(model).name
state = self.get_model_state(model)
event_data = EventData(state, Event(name=trigger, machine=self), self, model, args, kwargs)

for trigger_name in self.get_triggers(state):
if trigger_name != trigger:
continue
for transition in self.events[trigger_name].transitions[state]:
for transition in self.events[trigger_name].transitions[state.name]:
try:
_ = self.get_state(transition.dest) if transition.dest is not None else transition.source
except ValueError:
continue

evt.transition = transition
self.callbacks(self.prepare_event, evt)
self.callbacks(transition.prepare, evt)
if all(c.check(evt) for c in transition.conditions):
return True
event_data.transition = transition
try:
self.callbacks(self.prepare_event, event_data)
self.callbacks(transition.prepare, event_data)
if all(c.check(event_data) for c in transition.conditions):
return True
except BaseException as err:
event_data.error = err
if self.on_exception:
self.callbacks(self.on_exception, event_data)
else:
raise
return False

def _add_may_transition_func_for_trigger(self, trigger, model):
Expand Down
4 changes: 2 additions & 2 deletions transitions/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ class Transition:
TransitionConfig = Union[Sequence[Union[str, Any]], Dict[str, Any], Transition]

class EventData:
state: Optional[State]
event: Optional[Event]
state: State
event: Event
machine: Machine
model: object
args: Iterable[Any]
Expand Down
41 changes: 29 additions & 12 deletions transitions/extensions/asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,21 +425,29 @@ def remove_model(self, model):
self._transition_queue.extend(new_queue)

async def _can_trigger(self, model, trigger, *args, **kwargs):
evt = AsyncEventData(None, None, self, model, args, kwargs)
state = self.get_model_state(model).name
state = self.get_model_state(model)
event_data = AsyncEventData(state, AsyncEvent(name=trigger, machine=self), self, model, args, kwargs)

for trigger_name in self.get_triggers(state):
if trigger_name != trigger:
continue
for transition in self.events[trigger_name].transitions[state]:
for transition in self.events[trigger_name].transitions[state.name]:
try:
_ = self.get_state(transition.dest) if transition.dest is not None else transition.source
except ValueError:
continue
await self.callbacks(self.prepare_event, evt)
await self.callbacks(transition.prepare, evt)
if all(await self.await_all([partial(c.check, evt) for c in transition.conditions])):
return True
event_data.transition = transition
try:
await self.callbacks(self.prepare_event, event_data)
await self.callbacks(transition.prepare, event_data)
if all(await self.await_all([partial(c.check, event_data) for c in transition.conditions])):
return True
except BaseException as err:
event_data.error = err
if self.on_exception:
await self.callbacks(self.on_exception, event_data)
else:
raise
return False

def _process(self, trigger):
Expand Down Expand Up @@ -552,20 +560,29 @@ async def _can_trigger(self, model, trigger, *args, **kwargs):
return await self._can_trigger_nested(model, trigger, state_path, *args, **kwargs)

async def _can_trigger_nested(self, model, trigger, path, *args, **kwargs):
evt = AsyncEventData(None, None, self, model, args, kwargs)
if trigger in self.events:
source_path = copy.copy(path)
while source_path:
event_data = AsyncEventData(self.get_state(source_path), AsyncEvent(name=trigger, machine=self), self,
model, args, kwargs)
state_name = self.state_cls.separator.join(source_path)
for transition in self.events[trigger].transitions.get(state_name, []):
try:
_ = self.get_state(transition.dest) if transition.dest is not None else transition.source
except ValueError:
continue
await self.callbacks(self.prepare_event, evt)
await self.callbacks(transition.prepare, evt)
if all(await self.await_all([partial(c.check, evt) for c in transition.conditions])):
return True
event_data.transition = transition
try:
await self.callbacks(self.prepare_event, event_data)
await self.callbacks(transition.prepare, event_data)
if all(await self.await_all([partial(c.check, event_data) for c in transition.conditions])):
return True
except BaseException as err:
event_data.error = err
if self.on_exception:
await self.callbacks(self.on_exception, event_data)
else:
raise
source_path.pop(-1)
if path:
with self(path.pop(0)):
Expand Down
19 changes: 14 additions & 5 deletions transitions/extensions/nesting.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,20 +679,29 @@ def _can_trigger(self, model, trigger, *args, **kwargs):
return self._can_trigger_nested(model, trigger, state_path, *args, **kwargs)

def _can_trigger_nested(self, model, trigger, path, *args, **kwargs):
evt = NestedEventData(None, None, self, model, args, kwargs)
if trigger in self.events:
source_path = copy.copy(path)
while source_path:
event_data = EventData(self.get_state(source_path), Event(name=trigger, machine=self), self, model,
args, kwargs)
state_name = self.state_cls.separator.join(source_path)
for transition in self.events[trigger].transitions.get(state_name, []):
try:
_ = self.get_state(transition.dest) if transition.dest is not None else transition.source
except ValueError:
continue
self.callbacks(self.prepare_event, evt)
self.callbacks(transition.prepare, evt)
if all(c.check(evt) for c in transition.conditions):
return True
event_data.transition = transition
try:
self.callbacks(self.prepare_event, event_data)
self.callbacks(transition.prepare, event_data)
if all(c.check(event_data) for c in transition.conditions):
return True
except BaseException as err:
event_data.error = err
if self.on_exception:
self.callbacks(self.on_exception, event_data)
else:
raise
source_path.pop(-1)
if path:
with self(path.pop(0)):
Expand Down
4 changes: 2 additions & 2 deletions transitions/extensions/nesting.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class NestedEvent(Event):


class NestedEventData(EventData):
state: Optional[NestedState]
event: Optional[NestedEvent]
state: NestedState
event: NestedEvent
machine: HierarchicalMachine
transition: Optional[NestedTransition]
source_name: Optional[str]
Expand Down