Skip to content

Commit

Permalink
docs: Update (Async)TransitionConfig(Dict) and adjusted tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aleneum committed Aug 13, 2024
1 parent b9e4011 commit 80ae05a
Show file tree
Hide file tree
Showing 19 changed files with 206 additions and 84 deletions.
1 change: 1 addition & 0 deletions Changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

- Bug #683: Typing wrongly suggested that `Transition` instances can be passed to `Machine.__init__` and/or `Machine.add_transition(s)` (thanks @antonio-antuan)
- Typing should be more precise now
- Made `transitions.core.(Async)TransitionConfigDict` a `TypedDict` which can be used to spot parameter errors during static analysis
- `Machine.add_transitions` and `Machine.__init__` expect a `Sequence` of configurations for transitions now
- Added 'async' callbacks to types in `asyncio` extension

Expand Down
50 changes: 48 additions & 2 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

try:
import asyncio
from transitions.extensions.asyncio import AsyncMachine, HierarchicalAsyncMachine, AsyncEventData
from transitions.extensions.asyncio import AsyncMachine, HierarchicalAsyncMachine, AsyncEventData, \
AsyncTransition

except (ImportError, SyntaxError):
asyncio = None # type: ignore
Expand All @@ -17,7 +18,7 @@
from .test_pygraphviz import pgv

if TYPE_CHECKING:
from typing import Type, Sequence
from typing import Type, Sequence, List
from transitions.extensions.asyncio import AsyncTransitionConfig


Expand Down Expand Up @@ -584,6 +585,51 @@ async def run():

asyncio.run(run())

def test_custom_transition(self):

class MyTransition(self.machine_cls.transition_cls): # type: ignore

def __init__(self, source, dest, conditions=None, unless=None, before=None,
after=None, prepare=None, my_int=None, my_none=None, my_str=None, my_dict=None):
super(MyTransition, self).__init__(source, dest, conditions, unless, before, after, prepare)
self.my_int = my_int
self.my_none = my_none
self.my_str = my_str
self.my_dict = my_dict

class MyMachine(self.machine_cls): # type: ignore
transition_cls = MyTransition

a_transition = {
"trigger": "go", "source": "B", "dest": "A",
"my_int": 42, "my_str": "foo", "my_dict": {"bar": "baz"}
}
transitions = [
["go", "A", "B"],
a_transition
]

m = MyMachine(states=["A", "B"], transitions=transitions, initial="A")
m.add_transition("reset", "*", "A",
my_int=23, my_str="foo2", my_none=None, my_dict={"baz": "bar"})

async def run():
assert await m.go()
trans = m.get_transitions("go", "B") # type: List[MyTransition]
assert len(trans) == 1
assert trans[0].my_str == a_transition["my_str"]
assert trans[0].my_int == a_transition["my_int"]
assert trans[0].my_dict == a_transition["my_dict"]
assert trans[0].my_none is None
trans = m.get_transitions("reset", "A")
assert len(trans) == 1
assert trans[0].my_str == "foo2"
assert trans[0].my_int == 23
assert trans[0].my_dict == {"baz": "bar"}
assert trans[0].my_none is None

asyncio.run(run())


@skipIf(asyncio is None or (pgv is None and gv is None), "AsyncGraphMachine requires asyncio and (py)gaphviz")
class TestAsyncGraphMachine(TestAsync):
Expand Down
49 changes: 45 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
pass

import sys
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, List
from functools import partial
from unittest import TestCase, skipIf
import weakref

from transitions import Machine, MachineError, State, EventData
from transitions.core import listify, _prep_ordered_arg
from transitions.core import listify, _prep_ordered_arg, Transition

from .utils import InheritedStuff
from .utils import Stuff, DummyModel
Expand All @@ -23,7 +23,7 @@

if TYPE_CHECKING:
from typing import Sequence
from transitions.core import TransitionConfig, StateConfig
from transitions.core import TransitionConfig, StateConfig, TransitionConfigDict


def on_exit_A(event):
Expand Down Expand Up @@ -570,7 +570,7 @@ def test_pickle(self):
{'trigger': 'walk', 'source': 'A', 'dest': 'B'},
{'trigger': 'run', 'source': 'B', 'dest': 'C'},
{'trigger': 'sprint', 'source': 'C', 'dest': 'D'}
]
] # type: Sequence[TransitionConfigDict]
m = Machine(states=states, transitions=transitions, initial='A')
m.walk()
dump = pickle.dumps(m)
Expand Down Expand Up @@ -1350,3 +1350,44 @@ def test_on_final(self):
self.assertEqual(1, final_mock.call_count)
machine.to_B()
self.assertEqual(2, final_mock.call_count)

def test_custom_transition(self):

class MyTransition(self.machine_cls.transition_cls): # type: ignore

def __init__(self, source, dest, conditions=None, unless=None, before=None,
after=None, prepare=None, my_int=None, my_none=None, my_str=None, my_dict=None):
super(MyTransition, self).__init__(source, dest, conditions, unless, before, after, prepare)
self.my_int = my_int
self.my_none = my_none
self.my_str = my_str
self.my_dict = my_dict

class MyMachine(self.machine_cls): # type: ignore
transition_cls = MyTransition

a_transition = {
"trigger": "go", "source": "B", "dest": "A",
"my_int": 42, "my_str": "foo", "my_dict": {"bar": "baz"}
}
transitions = [
["go", "A", "B"],
a_transition
]

m = MyMachine(states=["A", "B"], transitions=transitions, initial="A")
m.add_transition("reset", "*", "A",
my_int=23, my_str="foo2", my_none=None, my_dict={"baz": "bar"})
assert m.go()
trans = m.get_transitions("go", "B") # type: List[MyTransition]
assert len(trans) == 1
assert trans[0].my_str == a_transition["my_str"]
assert trans[0].my_int == a_transition["my_int"]
assert trans[0].my_dict == a_transition["my_dict"]
assert trans[0].my_none is None
trans = m.get_transitions("reset", "A")
assert len(trans) == 1
assert trans[0].my_str == "foo2"
assert trans[0].my_int == 23
assert trans[0].my_dict == {"baz": "bar"}
assert trans[0].my_none is None
7 changes: 2 additions & 5 deletions tests/test_experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,8 @@ class Model:
def is_B(self) -> bool:
return False

transition_config = [["A", "B"], "C"] # type: TransitionConfig

@add_transitions(transition(source="A", dest="B"))
@add_transitions(transition_config)
@add_transitions([["A", "B"], "C"])
def go(self) -> bool:
raise RuntimeError("Should be overridden!")

Expand Down Expand Up @@ -196,8 +194,7 @@ class Model:
def is_B(self) -> bool:
return False

transition_config = [["A", "B"], "C"] # type: TransitionConfig
go = event(transition(source="A", dest="B"), transition_config)
go = event(transition(source="A", dest="B"), [["A", "B"], "C"], {"source": "*", "dest": None})

model = Model()
machine = self.trigger_machine(model, states=["A", "B", "C"], initial="A")
Expand Down
6 changes: 3 additions & 3 deletions tests/test_graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

if TYPE_CHECKING:
from typing import Type, List, Collection, Union, Literal, Sequence, Dict, Optional
from transitions.core import TransitionConfig
from transitions.core import TransitionConfig, TransitionConfigDict


class TestDiagramsImport(TestCase):
Expand Down Expand Up @@ -75,7 +75,7 @@ def setUp(self):
{'trigger': 'run', 'source': 'B', 'dest': 'C'},
{'trigger': 'sprint', 'source': 'C', 'dest': 'D', 'conditions': 'is_fast'},
{'trigger': 'sprint', 'source': 'C', 'dest': 'B'}
] # type: Sequence[Dict[str, str]]
] # type: Sequence[TransitionConfigDict]

def test_diagram(self):
m = self.machine_cls(states=self.states, transitions=self.transitions, initial='A', auto_transitions=False,
Expand Down Expand Up @@ -327,7 +327,7 @@ def setUp(self):
'conditions': 'is_fast'},
{'trigger': 'sprint', 'source': 'C', 'dest': 'B'}, # + 1 edge
{'trigger': 'reset', 'source': '*', 'dest': 'A'} # + 4 edges (from base state) = 8
] # type: Sequence[Dict[str, str]]
] # type: Sequence[TransitionConfigDict]

def test_diagram(self):
m = self.machine_cls(states=self.states, transitions=self.transitions, initial='A', auto_transitions=False,
Expand Down
11 changes: 5 additions & 6 deletions tests/test_reuse.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@

if TYPE_CHECKING:
from typing import List, Union, Dict, Any, Sequence
from transitions.core import TransitionConfig

from transitions.core import TransitionConfig, TransitionConfigDict

test_states = ['A', 'B', {'name': 'C', 'children': ['1', '2', {'name': '3', 'children': ['a', 'b', 'c']}]},
'D', 'E', 'F']
Expand Down Expand Up @@ -91,7 +90,7 @@ def test_blueprint_reuse(self):
{'trigger': 'decrease', 'source': '3', 'dest': '2'},
{'trigger': 'decrease', 'source': '1', 'dest': '1'},
{'trigger': 'reset', 'source': '*', 'dest': '1'},
]
] # type: Sequence[TransitionConfigDict]

counter = self.machine_cls(states=states, transitions=transitions, before_state_change='check',
after_state_change='clear', initial='1')
Expand All @@ -103,7 +102,7 @@ def test_blueprint_reuse(self):
{'trigger': 'backward', 'source': 'C', 'dest': 'B'},
{'trigger': 'backward', 'source': 'B', 'dest': 'A'},
{'trigger': 'calc', 'source': '*', 'dest': 'C'},
]
] # type: Sequence[TransitionConfigDict]

walker = self.machine_cls(states=new_states, transitions=new_transitions, before_state_change='watch',
after_state_change='look_back', initial='A')
Expand Down Expand Up @@ -144,7 +143,7 @@ def test_blueprint_remap(self):
{'trigger': 'decrease', 'source': '1', 'dest': '1'},
{'trigger': 'reset', 'source': '*', 'dest': '1'},
{'trigger': 'done', 'source': '3', 'dest': 'finished'}
]
] # type: Sequence[TransitionConfigDict]

counter = self.machine_cls(states=states, transitions=transitions, initial='1')

Expand All @@ -158,7 +157,7 @@ def test_blueprint_remap(self):
{'trigger': 'backward', 'source': 'C', 'dest': 'B'},
{'trigger': 'backward', 'source': 'B', 'dest': 'A'},
{'trigger': 'calc', 'source': '*', 'dest': 'C%s1' % State.separator},
]
] # type: Sequence[TransitionConfigDict]

walker = self.machine_cls(states=new_states, transitions=new_transitions, before_state_change='watch',
after_state_change='look_back', initial='A')
Expand Down
3 changes: 2 additions & 1 deletion tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ class Stuff(object):
is_false = False
is_True = True

def __init__(self, states=None, machine_cls=Machine, extra_kwargs={}):
def __init__(self, states=None, machine_cls=Machine, extra_kwargs=None):
extra_kwargs = extra_kwargs if extra_kwargs is not None else {}

self.state = None
self.message = None
Expand Down
33 changes: 22 additions & 11 deletions transitions/core.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from logging import Logger
from typing import (
Any, Optional, Callable, Sequence, Union, Iterable, List, Dict, DefaultDict,
Type, Deque, OrderedDict, Tuple, Literal, Collection, TypedDict, Mapping
Type, Deque, OrderedDict, Tuple, Literal, Collection, TypedDict, Mapping, Required
)

# Enums are supported for Python 3.4+ and Python 2.7 with enum34 package installed
Expand Down Expand Up @@ -91,8 +91,19 @@ TransitionConfigList = Union[
List[str], List[Sequence[str]], List[Optional[str]],
List[Union[str, Enum]], List[Optional[Union[str, Enum]]]
]
TransitionConfigDict = Mapping[str, Union[None, StateConfig, Callback, Iterable[Callback]]]
TransitionConfig = Union[TransitionConfigList, TransitionConfigDict]

class TransitionConfigDict(TypedDict, total=False):
trigger: Required[str]
source: Required[Union[str, Enum, Sequence[Union[str, Enum]]]]
dest: Required[Optional[Union[str, Enum]]]
prepare: CallbacksArg
before: CallbacksArg
after: CallbacksArg
conditions: CallbacksArg
unless: CallbacksArg

# For backwards compatibility we also accept untyped dictionaries/mappings
TransitionConfig = Union[TransitionConfigList, TransitionConfigDict, Mapping[str, Any]]

class EventData:
state: State
Expand All @@ -115,7 +126,7 @@ class Event:
transitions: DefaultDict[str, List[Transition]]
def __init__(self, name: str, machine: Machine) -> None: ...
def add_transition(self, transition: Transition) -> None: ...
def trigger(self, model: object, *args: List[Any], **kwargs: Dict[str, Any]) -> bool: ...
def trigger(self, model: object, *args: Any, **kwargs: Any) -> bool: ...
def _trigger(self, event_data: EventData) -> bool: ...
def _process(self, event_data: EventData) -> bool: ...
def _is_valid_source(self, state: State) -> bool: ...
Expand Down Expand Up @@ -157,7 +168,7 @@ class Machine:
name: str = ..., queued: bool = ...,
prepare_event: CallbacksArg = ..., finalize_event: CallbacksArg = ...,
model_attribute: str = ..., model_override: bool = ...,
on_exception: CallbacksArg = ..., on_final: CallbacksArg = ..., **kwargs: Dict[str, Any]) -> None: ...
on_exception: CallbacksArg = ..., on_final: CallbacksArg = ..., **kwargs: Any) -> None: ...
def add_model(self, model: ModelParameter,
initial: Optional[StateIdentifier] = ...) -> None: ...
def remove_model(self, model: ModelParameter) -> None: ...
Expand Down Expand Up @@ -201,21 +212,21 @@ class Machine:
def set_state(self, state: StateIdentifier, model: Optional[object] = ...) -> None: ...
def add_state(self, states: Union[Sequence[StateConfig], StateConfig],
on_enter: CallbacksArg = ..., on_exit: CallbacksArg = ...,
ignore_invalid_triggers: Optional[bool] = ..., **kwargs: Dict[str, Any]) -> None: ...
ignore_invalid_triggers: Optional[bool] = ..., **kwargs: Any) -> None: ...
def add_states(self, states: Union[Sequence[StateConfig], StateConfig],
on_enter: CallbacksArg = ..., on_exit: CallbacksArg = ...,
ignore_invalid_triggers: Optional[bool] = ..., **kwargs: Dict[str, Any]) -> None: ...
ignore_invalid_triggers: Optional[bool] = ..., **kwargs: Any) -> None: ...
def _add_model_to_state(self, state: State, model: object) -> None: ...
def _checked_assignment(self, model: object, name: str, func: CallbackFunc) -> None: ...
def _add_trigger_to_model(self, trigger: str, model: object) -> None: ...
def _get_trigger(self, model: object, trigger_name: str, *args: List[Any], **kwargs: Dict[str, Any]) -> bool: ...
def _get_trigger(self, model: object, trigger_name: str, *args: Any, **kwargs: Any) -> bool: ...
def get_triggers(self, *args: Union[str, Enum, State]) -> List[str]: ...
def add_transition(self, trigger: str,
source: Union[StateIdentifier, List[StateIdentifier]],
dest: Optional[StateIdentifier] = ...,
conditions: CallbacksArg = ..., unless: CallbacksArg = ...,
before: CallbacksArg = ..., after: CallbacksArg = ..., prepare: CallbacksArg = ...,
**kwargs: Dict[str, Any]) -> None: ...
**kwargs: Any) -> None: ...
def add_transitions(self, transitions: Sequence[TransitionConfig]) -> None: ...
def add_ordered_transitions(self, states: Optional[Sequence[Union[str, State]]] = ...,
trigger: str = ..., loop: bool = ...,
Expand All @@ -225,11 +236,11 @@ class Machine:
before: Optional[Sequence[Union[Callback, None]]] = ...,
after: Optional[Sequence[Union[Callback, None]]] = ...,
prepare: CallbacksArg = ...,
**kwargs: Dict[str, Any]) -> None: ...
**kwargs: Any) -> None: ...
def get_transitions(self, trigger: str = ...,
source: StateIdentifier = ..., dest: StateIdentifier = ...) -> List[Transition]: ...
def remove_transition(self, trigger: str, source: str = ..., dest: str = ...) -> None: ...
def dispatch(self, trigger: str, *args: List[Any], **kwargs: Dict[str, Any]) -> bool: ...
def dispatch(self, trigger: str, *args: Any, **kwargs: Any) -> bool: ...
def callbacks(self, funcs: Iterable[Callback], event_data: EventData) -> None: ...
def callback(self, func: Callback, event_data: EventData) -> None: ...
@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions transitions/experimental/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def generate_base_model(config):
f" def may_{trigger_name}(self) -> bool: {_placeholder_body}\n"
)

extra_params = "event_data: EventData" if m.send_event else "*args: List[Any], **kwargs: Dict[str, Any]"
extra_params = "event_data: EventData" if m.send_event else "*args: Any, **kwargs: Any"
for callback_name in callbacks:
if isinstance(callback_name, str):
callback_block += (f" @abstractmethod\n"
Expand Down Expand Up @@ -98,7 +98,7 @@ def add_model_override(self, model, initial=None):
self.model_override = True
for model in listify(model):
model = self if model == "self" else model
for name, specs in TriggerPlaceholder.definitions.get(model.__class__).items():
for name, specs in TriggerPlaceholder.definitions.get(model.__class__, {}).items():
for spec in specs:
if isinstance(spec, list):
self.add_transition(name, *spec)
Expand Down
Loading

0 comments on commit 80ae05a

Please sign in to comment.