Skip to content

Commit

Permalink
added: heatmap for jaccard distances (assumed comparison b/ networks)…
Browse files Browse the repository at this point in the history
…, and lollipop from df on arbitrary fields
  • Loading branch information
ivanovaos committed Jul 3, 2024
1 parent f585b2f commit cc9a726
Showing 1 changed file with 188 additions and 52 deletions.
240 changes: 188 additions & 52 deletions networkcommons/visual/_network_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,88 @@

from __future__ import annotations

__all__ = ['plot_n_nodes_edges']
__all__ = ['plot_n_nodes_edges', 'plot_n_nodes_edges_from_df', 'build_heatmap_with_tree']

from typing import List, Dict
from typing import List, Dict, Union

import lazy_import
import pandas as pd
plt = lazy_import.lazy_import('matplotlib.pyplot')
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import squareform


def lollipop_plot(
labels: List[str],
values: List[List[int]],
categories: List[List[str]],
orientation: str,
color_palette: str,
size: int,
linewidth: int,
marker: str,
title: str,
filepath: str = None,
render: bool = False
):
"""
Common function to plot metrics using a lollipop plot.
Args:
labels (List[str]): Labels for the x or y axis.
values (List[List[int]]): List of values to plot.
categories (List[List[str]]): List of categories (e.g., 'Nodes', 'Edges') for each label.
orientation (str): 'vertical' or 'horizontal'.
color_palette (str): Matplotlib color palette.
size (int): Size of the markers.
linewidth (int): Line width of the lollipops.
marker (str): Marker style for the lollipops.
title (str): Title of the plot.
filepath (str): Path to save the plot. Default is None.
render (bool): Whether to display the plot. Default is False.
"""
palette = plt.get_cmap(color_palette)
colors = palette.colors if hasattr(palette, 'colors') else palette(range(len(labels)))

fig, ax = plt.subplots(figsize=(12, 8))

for idx, (label, value_set, category_set) in enumerate(zip(labels, values, categories)):
color = colors[idx % len(colors)]

if orientation == 'vertical':
positions = [f"{label} {cat}" for cat in category_set]
ax.vlines(x=positions, ymin=0, ymax=value_set, color=color, linewidth=linewidth, label=label)
ax.scatter(positions, value_set, color=color, s=size ** 2, marker=marker, zorder=3)

for i, value in enumerate(value_set):
offset = size * 0.1 if value < 10 else size * 0.2
ax.text(positions[i], value + offset, str(value), ha='center', va='bottom', fontsize=size)
else:
positions = [f"{label} {cat}" for cat in category_set]
ax.hlines(y=positions, xmin=0, xmax=value_set, color=color, linewidth=linewidth, label=label)
ax.scatter(value_set, positions, color=color, s=size ** 2, marker=marker, zorder=3)

for i, value in enumerate(value_set):
offset = size * 0.1 if value < 10 else size * 0.2
ax.text(value + offset, positions[i], str(value), va='center', ha='left', fontsize=size)

if orientation == 'vertical':
ax.set_xlabel("Network and Type")
ax.set_ylabel("Count")
else:
ax.set_ylabel("Network and Type")
ax.set_xlabel("Count")

ax.set_title(title)
ax.legend()

if filepath:
plt.savefig(filepath)

if render:
plt.show()


def plot_n_nodes_edges(
Expand Down Expand Up @@ -59,67 +133,129 @@ def plot_n_nodes_edges(
if not show_nodes and not show_edges:
raise ValueError("At least one of 'show_nodes' or 'show_edges' must be True.")

# Set the color palette
palette = plt.get_cmap(color_palette)
colors = palette.colors if hasattr(palette, 'colors') else palette(range(len(networks)))
labels = []
values = []
categories = []

fig, ax = plt.subplots(figsize=(12, 8))

for idx, (network_name, network) in enumerate(networks.items()):
# Get the number of nodes and edges
for network_name, network in networks.items():
n_nodes = len(network.nodes)
n_edges = len(network.edges)
categories = []
values = []
value_set = []
category_set = []

if show_nodes:
categories.append('Nodes')
values.append(n_nodes)
category_set.append('Nodes')
value_set.append(n_nodes)
if show_edges:
categories.append('Edges')
values.append(n_edges)

color = colors[idx % len(colors)]

if orientation == 'vertical':
positions = [f"{network_name} {cat}" for cat in categories]
ax.vlines(x=positions, ymin=0, ymax=values, color=color, linewidth=linewidth, label=network_name)
ax.scatter(positions, values, color=color, s=size ** 2, marker=marker, zorder=3)
category_set.append('Edges')
value_set.append(n_edges)

# Annotate the values
for i, value in enumerate(values):
offset = size * 0.1 if value < 10 else size * 0.2
ax.text(positions[i], value + offset, str(value), ha='center', va='bottom', fontsize=size)
else:
positions = [f"{network_name} {cat}" for cat in categories]
ax.hlines(y=positions, xmin=0, xmax=values, color=color, linewidth=linewidth, label=network_name)
ax.scatter(values, positions, color=color, s=size ** 2, marker=marker, zorder=3)

# Annotate the values
for i, value in enumerate(values):
offset = size * 0.1 if value < 10 else size * 0.2
ax.text(value + offset, positions[i], str(value), va='center', ha='left', fontsize=size)
labels.append(network_name)
values.append(value_set)
categories.append(category_set)

# Set the axis labels
if orientation == 'vertical':
ax.set_xlabel("Network and Type")
ax.set_ylabel("Count")
else:
ax.set_ylabel("Network and Type")
ax.set_xlabel("Count")

# Set the title depending on the categories
title = "Number of Nodes and Edges"
if show_nodes and not show_edges:
title = "Number of Nodes"
elif show_edges and not show_nodes:
title = "Number of Edges"
ax.set_title(title)
ax.legend()

# Save the plot
if filepath is not None:
plt.savefig(filepath)
lollipop_plot(labels, values, categories, orientation, color_palette, size, linewidth, marker, title, filepath,
render)

if render:
plt.show()

def plot_n_nodes_edges_from_df(
metrics_df: pd.DataFrame,
metrics: List[str],
filepath=None,
render=False,
orientation='vertical',
color_palette='Set2',
size=10,
linewidth=2,
marker='o'
):
"""
Plot the specified metrics from a DataFrame using a lollipop plot.
Args:
metrics_df (pd.DataFrame): DataFrame containing metrics with networks as rows and specified metrics in columns.
metrics (List[str]): List of column names in the DataFrame to plot.
filepath (str): Path to save the plot. Default is None.
render (bool): Whether to display the plot. Default is False.
orientation (str): 'vertical' or 'horizontal'. Default is 'vertical'.
color_palette (str): Matplotlib color palette. Default is 'Set2'.
size (int): Size of the markers. Default is 10.
linewidth (int): Line width of the lollipops. Default is 2.
marker (str): Marker style for the lollipops. Default is 'o'.
"""
if not metrics:
raise ValueError("At least one metric must be specified.")

labels = []
values = []
categories = []

for network_name, row in metrics_df.iterrows():
value_set = []
category_set = []

for metric in metrics:
category_set.append(metric)
value_set.append(row[metric])

labels.append(network_name)
values.append(value_set)
categories.append(category_set)

title = "Metrics"
if len(metrics) == 1:
title = f"Number of {metrics[0]}"
else:
title = f"Number of {' and '.join(metrics)}"

lollipop_plot(labels, values, categories, orientation, color_palette, size, linewidth, marker, title, filepath,
render)


def build_heatmap_with_tree(
jaccard_df: pd.DataFrame,
title: str = "Heatmap (Jaccard Distance)",
palette: str = "viridis",
save: bool = False,
output_dir: str = "."
):
"""
Build a heatmap with hierarchical clustering based on a Jaccard distance matrix.
Args:
jaccard_df (pd.DataFrame): DataFrame containing Jaccard distance matrix.
title (str): Title of the plot.
palette (str): Color palette for the heatmap. Default is "viridis".
save (bool): Whether to save the plot. Default is False.
output_dir (str): Directory to save the plot. Default is ".".
"""
# Convert the square distance matrix to a condensed distance matrix
condensed_dist_matrix = squareform(jaccard_df)

# Perform hierarchical clustering
linked = linkage(condensed_dist_matrix, method='average')

# Create the clustermap
g = sns.clustermap(
jaccard_df,
row_linkage=linked,
col_linkage=linked,
cmap=palette,
figsize=(12, 10),
cbar_kws={'label': 'Jaccard Distance'}
)

# Adjust the position of the title to make sure it appears
# (depends on specific data/plot size)
g.fig.suptitle(title, fontsize=14, x=0.5, y=1.00)

if save:
plt.savefig(f"{output_dir}/heatmap_with_tree.png", bbox_inches='tight')

plt.show()

0 comments on commit cc9a726

Please sign in to comment.