diff --git a/finite_state_machine/draw_state_diagram.py b/finite_state_machine/draw_state_diagram.py index 0d12c19..2f02ae1 100644 --- a/finite_state_machine/draw_state_diagram.py +++ b/finite_state_machine/draw_state_diagram.py @@ -26,7 +26,6 @@ def generate_state_diagram_markdown(cls, initial_state): https://mermaid-js.github.io/mermaid/diagrams-and-syntax-and-examples/stateDiagram.html """ - class_fns = inspect.getmembers(cls, predicate=inspect.isfunction) state_transitions: List[Transition] = [ func.__fsm for name, func in class_fns if hasattr(func, "__fsm") diff --git a/finite_state_machine/state_machine.py b/finite_state_machine/state_machine.py index ede63e6..2bd2704 100644 --- a/finite_state_machine/state_machine.py +++ b/finite_state_machine/state_machine.py @@ -1,6 +1,7 @@ import functools import types from typing import NamedTuple, Union +import inspect from .exceptions import ConditionsNotMet, InvalidStartState @@ -45,7 +46,38 @@ def transition(source, target, conditions=None, on_error=None): raise ValueError("on_error needs to be a bool, int or string") def transition_decorator(func): - func.__fsm = Transition(func.__name__, source, target, conditions, on_error) + mems = inspect.getmembers(func) + state_machine_instance = [ + mem[1]["StateMachine"] for mem in mems if mem[0] == "__globals__" + ][0] + func.__fsm = Transition( + name=func.__name__, + source=source, + target=target, + conditions=conditions, + on_error=on_error, + ) + + # creating and/or adding items to __fsm attribute + if hasattr(state_machine_instance, "__fsm"): + if isinstance(source, list): + for src in source: + if src in state_machine_instance.__fsm: + state_machine_instance.__fsm[src].append(target) + else: + state_machine_instance.__fsm[src] = [target] + else: + if source in state_machine_instance.__fsm: + state_machine_instance.__fsm[src].append(target) + else: + state_machine_instance.__fsm[src] = [target] + else: + if isinstance(source, list): + state_machine_instance.__fsm = {} + for src in source: + state_machine_instance.__fsm[src] = [target] + else: + state_machine_instance.__fsm.source = [target] @functools.wraps(func) def _wrapper(*args, **kwargs):