Skip to content

Commit

Permalink
feat: enhance mermaid graph rendering with unique node IDs and subgra…
Browse files Browse the repository at this point in the history
…ph support

chore: add tests for most functionality in web_interface.py
  • Loading branch information
provos committed Jan 27, 2025
1 parent 96fc68c commit 51af427
Show file tree
Hide file tree
Showing 2 changed files with 383 additions and 6 deletions.
68 changes: 62 additions & 6 deletions src/planai/web_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,21 +184,77 @@ def render_mermaid_graph(graph: "Graph"):
# Start with graph definition
mermaid = """graph TD\n"""

# Track unique IDs for nodes and subgraphs
unique_id = 0
mapping = {}
subgraph_workers = []

# Helper function to get or create unique ID for a worker
def get_worker_id(worker):
nonlocal unique_id
if worker not in mapping:
unique_id += 1
mapping[worker] = f"task_{unique_id}"
return mapping[worker]

# Helper function to create node definition
def create_node(worker):
worker_id = get_worker_id(worker)
return f"{worker_id}[{worker.name.replace(' ', '_')}]"

def is_subgraph(worker):
if not hasattr(worker, "graph"):
return False
if not isinstance(worker.graph, object):
return False
return hasattr(worker.graph, "dependencies")

# First pass: identify subgraphs and create all node mappings
for upstream, downstream_list in graph.dependencies.items():
get_worker_id(upstream)
for downstream in downstream_list:
get_worker_id(downstream)
# Check if either worker is a SubGraphWorker
if is_subgraph(upstream):
if upstream not in subgraph_workers:
subgraph_workers.append(upstream)
if is_subgraph(downstream):
if downstream not in subgraph_workers:
subgraph_workers.append(downstream)

# Add subgraphs first
for worker in subgraph_workers:
subgraph_id = get_worker_id(worker)
mermaid += f" subgraph {subgraph_id}[{worker.name}]\n"

# Add nodes and edges within the subgraph
for sub_upstream, sub_downstream_list in worker.graph.dependencies.items():
sub_upstream_id = get_worker_id(sub_upstream)
mermaid += f" {create_node(sub_upstream)}\n"
for sub_downstream in sub_downstream_list:
sub_downstream_id = get_worker_id(sub_downstream)
mermaid += f" {create_node(sub_downstream)}\n"
mermaid += f" {sub_upstream_id}-->{sub_downstream_id}\n"

mermaid += " end\n"

# Add all edges from dependencies
for upstream, downstream_list in graph.dependencies.items():
upstream_id = get_worker_id(upstream)
for downstream in downstream_list:
# Create an edge for each dependency using worker names
# Sanitize names by replacing spaces with underscores
src = upstream.name.replace(" ", "_")
dst = downstream.name.replace(" ", "_")
mermaid += f" {src}-->{dst}\n"
downstream_id = get_worker_id(downstream)
mermaid += f" {upstream_id}[{upstream.name}]-->{downstream_id}[{downstream.name}]\n"

return mermaid


def run_web_interface(disp: "Dispatcher", port=5000):
def set_dispatcher(disp: "Dispatcher"):
global dispatcher
dispatcher = disp


def run_web_interface(disp: "Dispatcher", port=5000):
set_dispatcher(disp)
app.run(debug=False, use_reloader=False, port=port)


Expand Down
Loading

0 comments on commit 51af427

Please sign in to comment.