diff --git a/bluecellulab/graph.py b/bluecellulab/graph.py index a65cfc66..7afd0502 100644 --- a/bluecellulab/graph.py +++ b/bluecellulab/graph.py @@ -42,7 +42,7 @@ def plot_graph(G: nx.Graph, node_size: float = 400, edge_width: float = 0.4, nod populations = list(set([cell_id.population_name for cell_id in G.nodes()])) # Create a color map for each population - color_map = plt.cm.tab20(np.linspace(0, 1, len(populations))) + color_map = plt.cm.tab20(np.linspace(0, 1, len(populations))) # type: ignore[attr-defined] population_color = dict(zip(populations, color_map)) # Create node colors based on their population @@ -50,7 +50,7 @@ def plot_graph(G: nx.Graph, node_size: float = 400, edge_width: float = 0.4, nod # Extract weights for edge color mapping edge_weights = [d['weight'] for _, _, d in G.edges(data=True)] - edge_colors = plt.cm.Greens(np.interp(edge_weights, (min(edge_weights), max(edge_weights)), (0.3, 1))) + edge_colors = plt.cm.Greens(np.interp(edge_weights, (min(edge_weights), max(edge_weights)), (0.3, 1))) # type: ignore[attr-defined] # Create positions using spring layout for the entire graph pos = nx.spring_layout(G, k=node_distance)