diff --git a/networkcommons/_visual/__init__.py b/networkcommons/_visual/__init__.py index 5fa9849..fb98e96 100644 --- a/networkcommons/_visual/__init__.py +++ b/networkcommons/_visual/__init__.py @@ -17,4 +17,5 @@ Visualization methods for networks and analyses. """ -from . import networkx as networkx +from . import vis_networkx as vis_networkx +from . import yfiles as vis_yfiles \ No newline at end of file diff --git a/networkcommons/_visual/_aux.py b/networkcommons/_visual/_aux.py new file mode 100644 index 0000000..b15d4bf --- /dev/null +++ b/networkcommons/_visual/_aux.py @@ -0,0 +1,11 @@ + +def wrap_node_name(node_name): + if ":" in node_name: + node_name = node_name.replace(":", "_") + if node_name.startswith("COMPLEX"): + # remove the word COMPLEX with a separator (:/-, etc) + return node_name[8:] + else: + return node_name + + diff --git a/networkcommons/_visual/_network_mock.py b/networkcommons/_visual/_network_mock.py new file mode 100644 index 0000000..6e2090d --- /dev/null +++ b/networkcommons/_visual/_network_mock.py @@ -0,0 +1,91 @@ +from graphviz import Digraph +from IPython.display import display +from yfiles_jupyter_graphs import GraphWidget +from typing import Dict, List +import networkx as nx +import pandas as pd +import matplotlib.pyplot as plt +import datetime +import os +from pypath.utils import mapping + +from _aux import wrap_node_name +import yfiles + +class NetworkMock: + + def __init__(self): + self.network = nx.DiGraph() + self.color_map = {} + + def init_from_sif(self, filepath): + """ + Initialize the network from a SIF file + """ + with open(filepath, 'r') as f: + for line in f: + source, target = line.strip().split() + self.network.add_edge(source, target) + return self.network + + def add_nodes(self, nodes): + # add nodes to networkx + self.network.add_nodes_from(nodes) + return self.network + + def set_initial_nodes(self, initial_nodes): + self.update_node_property(initial_nodes, type="initial_node") + + def 
set_nodes_of_interest(self, nodes): + self.update_node_property(nodes, type="noi", value="1") + + def add_edges(self, edges): + self.network.add_edges_from(edges) + + def update_node_property(self, node, type="color", value="blue"): + # add color to the node in networkx + if type == "color": + self.color_map[node] = value + elif type == "initial_node": + self.network.nodes[node]["initial_node"] = True + return self.color_map + + def update_edge_property(self, edge, type="color", color="blue"): + # add color to the edge in networkx + if type == "color": + self.color_map[edge] = color + return self.color_map + + def draw(self, filepath=None, render=False): + networkx_vis = yfiles.NetworkXVisualizer() + networkx_vis.visualise(render=render) + + def mapping_node_identifier(self, node: str) -> List[str]: + complex_string = None + gene_symbol = None + uniprot = None + + if mapping.id_from_label0(node): + uniprot = mapping.id_from_label0(node) + gene_symbol = mapping.label(uniprot) + elif node.startswith("COMPLEX"): + node = node[8:] #TODO change to wrap + node_list = node.split("_") + translated_node_list = [mapping.label(mapping.id_from_label0(item)) for item in node_list] + complex_string = "COMPLEX:" + "_".join(translated_node_list) + elif mapping.label(node): + gene_symbol = mapping.label(node) + uniprot = mapping.id_from_label0(gene_symbol) + else: + print("Error during translation, check syntax for ", node) + + return [complex_string, gene_symbol, uniprot] + + def convert_edges_into_genesymbol(self, edges: pd.DataFrame) -> pd.DataFrame: + def convert_identifier(x): + identifiers = self.mapping_node_identifier(x) + return identifiers[1] # Using GeneSymbol + edges["source"] = edges["source"].apply(convert_identifier) + edges["target"] = edges["target"].apply(convert_identifier) + return edges + diff --git a/networkcommons/_visual/network_stats.py b/networkcommons/_visual/network_stats.py new file mode 100644 index 0000000..2aabdc0 --- /dev/null +++ 
b/networkcommons/_visual/network_stats.py @@ -0,0 +1,128 @@ +import pandas as pd +import matplotlib.pyplot as plt +import networkx as nx +from typing import List, Dict + + +def plot_n_nodes_edges( + networks: Dict[str, nx.DiGraph], + filepath=None, + render=False, + orientation='vertical', + color_palette='Set2', + size=10, + linewidth=2, + marker='o', + show_nodes=True, + show_edges=True +): + """ + Plot the number of nodes and edges in the networks using a lollipop plot. + + Args: + networks (Dict[str, nx.DiGraph]): A dictionary of network names and their corresponding graphs. + filepath (str): Path to save the plot. Default is None. + render (bool): Whether to display the plot. Default is False. + orientation (str): 'vertical' or 'horizontal'. Default is 'vertical'. + color_palette (str): Matplotlib color palette. Default is 'Set2'. + size (int): Size of the markers. Default is 10. + linewidth (int): Line width of the lollipops. Default is 2. + marker (str): Marker style for the lollipops. Default is 'o'. + show_nodes (bool): Whether to show nodes count. Default is True. + show_edges (bool): Whether to show edges count. Default is True. 
+ """ + if not show_nodes and not show_edges: + raise ValueError("At least one of 'show_nodes' or 'show_edges' must be True.") + + # Set the color palette + palette = plt.get_cmap(color_palette) + colors = palette.colors if hasattr(palette, 'colors') else palette(range(len(networks))) + + fig, ax = plt.subplots(figsize=(12, 8)) + + for idx, (network_name, network) in enumerate(networks.items()): + # Get the number of nodes and edges + n_nodes = len(network.nodes) + n_edges = len(network.edges) + categories = [] + values = [] + + if show_nodes: + categories.append('Nodes') + values.append(n_nodes) + if show_edges: + categories.append('Edges') + values.append(n_edges) + + color = colors[idx % len(colors)] + + if orientation == 'vertical': + positions = [f"{network_name} {cat}" for cat in categories] + ax.vlines(x=positions, ymin=0, ymax=values, color=color, linewidth=linewidth, label=network_name) + ax.scatter(positions, values, color=color, s=size ** 2, marker=marker, zorder=3) + + # Annotate the values + for i, value in enumerate(values): + offset = size * 0.1 if value < 10 else size * 0.2 + ax.text(positions[i], value + offset, str(value), ha='center', va='bottom', fontsize=size) + else: + positions = [f"{network_name} {cat}" for cat in categories] + ax.hlines(y=positions, xmin=0, xmax=values, color=color, linewidth=linewidth, label=network_name) + ax.scatter(values, positions, color=color, s=size ** 2, marker=marker, zorder=3) + + # Annotate the values + for i, value in enumerate(values): + offset = size * 0.1 if value < 10 else size * 0.2 + ax.text(value + offset, positions[i], str(value), va='center', ha='left', fontsize=size) + + # Set the axis labels + if orientation == 'vertical': + ax.set_xlabel("Network and Type") + ax.set_ylabel("Count") + else: + ax.set_ylabel("Network and Type") + ax.set_xlabel("Count") + + # Set the title depending on the categories + title = "Number of Nodes and Edges" + if show_nodes and not show_edges: + title = "Number of Nodes" + 
elif show_edges and not show_nodes: + title = "Number of Edges" + ax.set_title(title) + + # Add a legend + ax.legend() + + # Save the plot + if filepath is not None: + plt.savefig(filepath) + + # Render the plot + if render: + plt.show() + + +# Test the function with sample networks +# if __name__ == "__main__": +# # Create sample directed graphs +# G1 = nx.DiGraph() +# G1.add_nodes_from(range(10)) # Adding 10 nodes +# G1.add_edges_from([(i, i + 1) for i in range(9)]) # Adding 9 edges +# +# G2 = nx.DiGraph() +# G2.add_nodes_from(range(15)) # Adding 15 nodes +# G2.add_edges_from([(i, i + 1) for i in range(14)]) # Adding 14 edges +# +# G3 = nx.DiGraph() +# G3.add_nodes_from(range(5)) # Adding 5 nodes +# G3.add_edges_from([(i, (i + 1) % 5) for i in range(5)]) # Adding 5 edges +# +# G4 = nx.DiGraph() +# G4.add_nodes_from(range(20)) # Adding 20 nodes +# G4.add_edges_from([(i, i + 1) for i in range(19)]) # Adding 19 edges +# +# networks = {'Network1': G1, 'Network2': G2, 'Network3': G3, 'Network4': G4} +# +# plot_n_nodes_edges(networks, filepath="nodes_edges_plot.png", render=True, orientation='horizontal', +# color_palette='Set2', size=12, linewidth=2, marker='o', show_nodes=True, show_edges=True) diff --git a/networkcommons/_visual/networkx.py b/networkcommons/_visual/networkx.py deleted file mode 100644 index e397d46..0000000 --- a/networkcommons/_visual/networkx.py +++ /dev/null @@ -1,383 +0,0 @@ -#!/usr/bin/env python - -# -# This file is part of the `networkcommons` Python module -# -# Copyright 2024 -# Heidelberg University Hospital -# -# File author(s): Saez Lab (omnipathdb@gmail.com) -# -# Distributed under the GPLv3 license -# See the file `LICENSE` or read a copy at -# https://www.gnu.org/licenses/gpl-3.0.txt -# - -""" -This module contains functions to visualize networks. -The styles for different types of networks are defined in the get_styles() function. 
-The set_style_attributes() function is used to set attributes for nodes and edges based on the given styles. -The visualize_network_default() function visualizes the graph with default style. -The visualize_network_sign_consistent() function visualizes the graph considering sign consistency. -The visualize_network() function is the main function to visualize the graph based on the network type. -""" - -import networkx as nx - -from networkcommons._session import _log - -__all__ = [ - 'get_styles', - 'set_style_attributes', - 'merge_styles', - 'visualize_network', - 'visualize_network_default', - 'visualize_network_sign_consistent', - 'visualize_big_graph', - 'visualize_graph_split', -] - - -def get_styles(): - """ - Return a dictionary containing styles for different types of networks. - """ - styles = { - 'default': { - 'nodes': { - 'sources': { - 'shape': 'circle', - 'color': 'steelblue', - 'style': 'filled', - 'fillcolor': 'steelblue', - 'label': '', - 'penwidth': 3 - }, - 'targets': { - 'shape': 'circle', - 'color': 'mediumpurple1', - 'style': 'filled', - 'fillcolor': 'mediumpurple1', - 'label': '', - 'penwidth': 3 - }, - 'other': { - 'shape': 'circle', - 'color': 'gray', - 'style': 'filled', - 'fillcolor': 'gray', - 'label': '' - } - }, - 'edges': { - 'neutral': { - 'color': 'gray30', - 'penwidth': 2 - } - } - }, - 'sign_consistent': { - 'nodes': { - 'sources': { - 'default': { - 'shape': 'circle', - 'style': 'filled', - 'fillcolor': 'steelblue', - 'label': '', - 'penwidth': 3, - 'color': 'steelblue' - }, - 'positive_consistent': { - 'color': 'forestgreen' - }, - 'negative_consistent': { - 'color': 'tomato3' - } - }, - 'targets': { - 'default': { - 'shape': 'circle', - 'style': 'filled', - 'fillcolor': 'mediumpurple1', - 'label': '', - 'penwidth': 3, - 'color': 'mediumpurple1' - }, - 'positive_consistent': { - 'color': 'forestgreen' - }, - 'negative_consistent': { - 'color': 'tomato3' - } - }, - 'other': { - 'default': { - 'shape': 'circle', - 'color': 
'gray', - 'style': 'filled', - 'fillcolor': 'gray', - 'label': '' - } - } - }, - 'edges': { - 'positive': { - 'color': 'forestgreen', - 'penwidth': 2 - }, - 'negative': { - 'color': 'tomato3', - 'penwidth': 2 - }, - 'neutral': { - 'color': 'gray30', - 'penwidth': 2 - } - } - }, - # Add more network styles here - } - - return styles - - -def set_style_attributes(item, base_style, condition_style=None): - """ - Set attributes for a graph item (node or edge) based on the given styles. - - Args: - item (node or edge): The item to set attributes for. - base_style (dict): The base style dictionary with default attribute settings. - condition_style (dict, optional): A dictionary of attribute settings for specific conditions. Defaults to None. - """ - for attr, value in base_style.items(): - item.attr[attr] = value - - if condition_style: - for attr, value in condition_style.items(): - item.attr[attr] = value - - -def merge_styles(default_style, custom_style, path=""): - """ - Merge custom styles with default styles to ensure all necessary fields are present. - - Args: - default_style (dict): The default style dictionary. - custom_style (dict): The custom style dictionary. - path (str): The path in the dictionary hierarchy for logging purposes. - - Returns: - dict: The merged style dictionary. - """ - merged_style = default_style.copy() - if custom_style is not None: - for key, value in custom_style.items(): - if isinstance(value, dict) and key in merged_style: - merged_style[key] = merge_styles(merged_style[key], value, f"{path}.{key}" if path else key) - else: - merged_style[key] = value - - # Log missing keys in custom_style - for key in default_style: - if key not in custom_style: - _log(f"Missing key '{path}.{key}' in custom style. Using default value.") - - return merged_style - - -def visualize_network_default(network, source_dict, target_dict, prog='dot', custom_style=None): - """ - Core function to visualize the graph. 
- - Args: - network (nx.Graph): The network to visualize. - source_dict (dict): A dictionary containing the sources and sign of perturbation. - target_dict (dict): A dictionary containing the targets and sign of measurements. - prog (str, optional): The layout program to use. Defaults to 'dot'. - custom_style (dict, optional): The custom style to apply. If None, the default style is used. - """ - default_style = get_styles()['default'] - style = merge_styles(default_style, custom_style) - - A = nx.nx_agraph.to_agraph(network) - A.graph_attr['ratio'] = '1.2' - - sources = set(source_dict.keys()) - target_dict_flat = {sub_key: sub_value for key, value in target_dict.items() for sub_key, sub_value in value.items()} - targets = set(target_dict_flat.keys()) - - for node in A.nodes(): - n = node.get_name() - if n in sources: - base_style = style['nodes']['sources'] - elif n in targets: - base_style = style['nodes']['targets'] - else: - base_style = style['nodes']['other'] - - set_style_attributes(node, base_style) - - for edge in A.edges(): - edge_style = style['edges']['neutral'] - set_style_attributes(edge, edge_style) - - A.layout(prog=prog) - return A - - -def visualize_network_sign_consistent(network, source_dict, target_dict, prog='dot', custom_style=None): - """ - Visualize the graph considering sign consistency. - - Args: - network (nx.Graph): The network to visualize. - source_dict (dict): A dictionary containing the sources and sign of perturbation. - target_dict (dict): A dictionary containing the targets and sign of measurements. - prog (str, optional): The layout program to use. Defaults to 'dot'. - custom_style (dict, optional): The custom style to apply. Defaults to None. 
- """ - default_style = get_styles()['sign_consistent'] - style = merge_styles(default_style, custom_style) - - # Call the core visualization function - A = visualize_network_default(network, source_dict, target_dict, prog, style) - - sources = set(source_dict.keys()) - target_dict_flat = {sub_key: sub_value for key, value in target_dict.items() for sub_key, sub_value in value.items()} - targets = set(target_dict_flat.keys()) - - for node in A.nodes(): - n = node.get_name() - condition_style = None - sign_value = target_dict_flat.get(n, 1) - - if n in sources: - nodes_type = "sources" - elif n in targets: - nodes_type = "targets" - - if sign_value > 0: - condition_style = style['nodes'][nodes_type].get('positive_consistent') - elif sign_value < 0: - condition_style = style['nodes'][nodes_type].get('negative_consistent') - - if condition_style: - set_style_attributes(node, {}, condition_style) # Apply condition style without overwriting base style - - for edge in A.edges(): - u, v = edge - edge_data = network.get_edge_data(u, v) - if 'interaction' in edge_data: - edge_data['sign'] = edge_data.pop('interaction') - - if edge_data['sign'] == 1: - edge_style = style['edges']['positive'] - elif edge_data['sign'] == -1: - edge_style = style['edges']['negative'] - else: - edge_style = style['edges']['neutral'] - - set_style_attributes(edge, edge_style) - - return A - - -def visualize_network(network, source_dict, target_dict, prog='dot', network_type='default', custom_style=None): - """ - Main function to visualize the graph based on the network type. - - Args: - network (nx.Graph): The network to visualize. - source_dict (dict): A dictionary containing the sources and sign of perturbation. - target_dict (dict): A dictionary containing the targets and sign of measurements. - prog (str, optional): The layout program to use. Defaults to 'dot'. - network_type (str, optional): The type of visualization to use. Defaults to "default". 
- custom_style (dict, optional): The custom style to apply. Defaults to None. - """ - if network_type == 'sign_consistent': - return visualize_network_sign_consistent(network, source_dict, target_dict, prog, custom_style) - else: - default_style = get_styles().get(network_type, get_styles()['default']) - return visualize_network_default(network, source_dict, target_dict, prog, custom_style) - - -def visualize_big_graph(): - return NotImplementedError - - -def visualize_graph_split(): - return NotImplementedError - - -#----------------------------- -# Test examples -# import matplotlib.pyplot as plt - -# # Create a sample graph -# G = nx.DiGraph() -# G.add_node("A") -# G.add_node("B") -# G.add_node("C") -# G.add_edge("A", "B") -# G.add_edge("B", "C") -# G.add_edge("C", "A") - -# # Define source and target dictionaries -# source_dict = {"A": 1, "B": -1} -# target_dict = {"C": {"value": 1}} - -# # Basic Example with Default Style -# A = visualize_network(G, source_dict, target_dict, prog='dot', network_type='default') -# A.draw("default_style.png", format='png') -# plt.imshow(plt.imread("default_style.png")) -# plt.axis('off') -# plt.show() - -# # Example with Custom Style -# custom_style = { -# 'nodes': { -# 'sources': { -# 'shape': 'rectangle', -# 'color': 'red', -# 'style': 'filled', -# 'fillcolor': 'red', -# 'penwidth': 2 -# }, -# 'targets': { -# 'shape': 'ellipse', -# 'color': 'blue', -# 'style': 'filled', -# 'fillcolor': 'lightblue', -# 'penwidth': 2 -# }, -# 'other': { -# 'shape': 'diamond', -# 'color': 'green', -# 'style': 'filled', -# 'fillcolor': 'lightgreen', -# 'penwidth': 2 -# } -# }, -# 'edges': { -# 'neutral': { -# 'color': 'black', -# 'penwidth': 1 -# } -# } -# } -# A = visualize_network(G, source_dict, target_dict, prog='dot', network_type='default', custom_style=custom_style) -# A.draw("custom_style.png", format='png') -# plt.imshow(plt.imread("custom_style.png")) -# plt.axis('off') -# plt.show() - -# # Example with Sign Consistent Network -# 
G["A"]["B"]["interaction"] = 1 -# G["B"]["C"]["interaction"] = -1 -# G["C"]["A"]["interaction"] = 1 -# A = visualize_network(G, source_dict, target_dict, prog='dot', network_type='sign_consistent') -# A.draw("sign_consistent_style.png", format='png') -# plt.imshow(plt.imread("sign_consistent_style.png")) -# plt.axis('off') -# plt.show() diff --git a/networkcommons/_visual/rnaseq.py b/networkcommons/_visual/rnaseq.py new file mode 100644 index 0000000..84aa6dd --- /dev/null +++ b/networkcommons/_visual/rnaseq.py @@ -0,0 +1,295 @@ +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.decomposition import PCA +from typing import List + + +def build_volcano_plot( + data: pd.DataFrame, + log2fc: str, + pval: str, + pval_threshold: float = 0.05, + log2fc_threshold: float = 1, + title: str = "Volcano Plot", + xlabel: str = "log2 Fold Change", + ylabel: str = "-log10(p-value)", + colors: tuple = ("gray", "red"), + alpha: float = 0.7, + size: int = 50, + save: bool = False, + output_dir: str = "." 
+): + data['-log10(pval)'] = -np.log10(data[pval]) + data['significant'] = (data[pval] < pval_threshold) & (abs(data[log2fc]) >= log2fc_threshold) + + fig, ax = plt.subplots(figsize=(10, 10)) + + ax.scatter( + data.loc[~data['significant'], log2fc], + data.loc[~data['significant'], '-log10(pval)'], + c=colors[0], + alpha=alpha, + s=size, + label='Non-significant' + ) + + ax.scatter( + data.loc[data['significant'], log2fc], + data.loc[data['significant'], '-log10(pval)'], + c=colors[1], + alpha=alpha, + s=size, + label='Significant' + ) + + ax.axhline( + -np.log10(pval_threshold), + color="blue", + linestyle="--" + ) + + ax.axvline( + log2fc_threshold, + color="blue", + linestyle="--" + ) + ax.axvline( + -log2fc_threshold, + color="blue", + linestyle="--" + ) + + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.legend() + + if save: + plt.savefig(f"{output_dir}/volcano_plot.png") + + plt.show() + + +def build_ma_plot( + data: pd.DataFrame, + log2fc: str, + mean_exp: str, + log2fc_threshold: float = 1, + title: str = "MA Plot", + xlabel: str = "Mean Expression", + ylabel: str = "log2 Fold Change", + colors: tuple = ("gray", "red"), + alpha: float = 0.7, + size: int = 50, + save: bool = False, + output_dir: str = "." 
+): + data['significant'] = abs(data[log2fc]) >= log2fc_threshold + + fig, ax = plt.subplots(figsize=(10, 10)) + + ax.scatter( + data.loc[~data['significant'], mean_exp], + data.loc[~data['significant'], log2fc], + c=colors[0], + alpha=alpha, + s=size, + label='Non-significant' + ) + + ax.scatter( + data.loc[data['significant'], mean_exp], + data.loc[data['significant'], log2fc], + c=colors[1], + alpha=alpha, + s=size, + label='Significant' + ) + + ax.axhline( + 0, + color="blue", + linestyle="--" + ) + + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.legend() + + if save: + plt.savefig(f"{output_dir}/ma_plot.png") + + plt.show() + + +def build_pca_plot( + data: pd.DataFrame, + title: str = "PCA Plot", + xlabel: str = "PC1", + ylabel: str = "PC2", + alpha: float = 0.7, + size: int = 50, + save: bool = False, + output_dir: str = "." +): + pca = PCA(n_components=2) + principal_components = pca.fit_transform(data) + pca_df = pd.DataFrame(data=principal_components, columns=['PC1', 'PC2']) + + fig, ax = plt.subplots(figsize=(10, 10)) + + ax.scatter( + pca_df['PC1'], + pca_df['PC2'], + alpha=alpha, + s=size + ) + + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(title) + + if save: + plt.savefig(f"{output_dir}/pca_plot.png") + + plt.show() + + +def build_heatmap_with_tree( + data: pd.DataFrame, + top_n: int = 50, + value_column: str = 'log2FoldChange_condition_1', + conditions: List[str] = None, + title: str = "Heatmap of Top Differentially Expressed Genes", + save: bool = False, + output_dir: str = "." +): + """ + Build a heatmap with hierarchical clustering for the top differentially expressed genes across multiple conditions. + + Args: + data (pd.DataFrame): DataFrame containing RNA-seq results. + top_n (int): Number of top differentially expressed genes to include in the heatmap. + value_column (str): Column name for the values to rank and select the top genes. 
+ conditions (List[str]): List of condition columns to include in the heatmap. + title (str): Title of the plot. + save (bool): Whether to save the plot. Default is False. + output_dir (str): Directory to save the plot. Default is ".". + """ + if conditions is None: + raise ValueError("Conditions must be provided as a list of column names.") + + # Select top differentially expressed genes + top_genes = data.nlargest(top_n, value_column).index + top_data = data.loc[top_genes, conditions] + + # Create the clustermap + g = sns.clustermap(top_data, cmap="viridis", cbar=True, fmt=".2f", linewidths=.5) + + plt.title(title) + plt.ylabel("Gene") + plt.xlabel("Condition") + + if save: + plt.savefig(f"{output_dir}/heatmap_with_tree.png") + + plt.show() + + +# Test the functions with the generated dataset +# Example condition columns for plotting +# np.random.seed(42) +# +# # Number of genes and conditions +# num_genes = 1000 +# num_conditions = 4 +# +# # Generate gene names +# genes = [f"gene_{i}" for i in range(num_genes)] +# +# # Generate mean expression values for each gene +# mean_expression = np.random.uniform(1, 100, num_genes) +# +# # Generate log2 fold changes and p-values for each condition +# log2_fold_changes = { +# f'log2FoldChange_condition_{i + 1}': np.random.randn(num_genes) for i in range(num_conditions) +# } +# p_values = { +# f'pvalue_condition_{i + 1}': np.random.uniform(0, 1, num_genes) for i in range(num_conditions) +# } +# +# # Combine data into a DataFrame +# data = pd.DataFrame({ +# 'meanExpression': mean_expression, +# **log2_fold_changes, +# **p_values +# }, index=genes) +# +# # Display the first few rows of the generated data +# print(data.head()) +# +# # Save the data to a CSV file for later use (optional) +# data.to_csv('simulated_rnaseq_data.csv') +# +# conditions = [f'log2FC_condition_{i + 1}' for i in range(num_conditions)] +# +# # Plot the Volcano plot for the first condition +# build_volcano_plot( +# data=data, +# 
log2fc='log2FoldChange_condition_1', +# pval='pvalue_condition_1', +# pval_threshold=0.05, +# log2fc_threshold=1, +# title="Sample Volcano Plot", +# xlabel="log2 Fold Change", +# ylabel="-log10(p-value)", +# colors=("gray", "red"), +# alpha=0.7, +# size=20, +# save=False, +# output_dir="." +# ) +# +# # Plot the MA plot for the first condition +# build_ma_plot( +# data=data, +# log2fc='log2FoldChange_condition_1', +# mean_exp='meanExpression', +# log2fc_threshold=1, +# title="Sample MA Plot", +# xlabel="Mean Expression", +# ylabel="log2 Fold Change", +# colors=("gray", "red"), +# alpha=0.7, +# size=20, +# save=False, +# output_dir="." +# ) +# +# # Plot the PCA plot for all conditions +# pca_conditions = [f'log2FoldChange_condition_{i + 1}' for i in range(num_conditions)] +# build_pca_plot( +# data=data[pca_conditions], +# title="Sample PCA Plot", +# xlabel="PC1", +# ylabel="PC2", +# alpha=0.7, +# size=20, +# save=False, +# output_dir="." +# ) +# +# +# # Plot the Heatmap with hierarchical clustering for the top differentially expressed genes across all conditions +# build_heatmap_with_tree( +# data=data, +# top_n=50, +# value_column='log2FoldChange_condition_1', +# conditions=pca_conditions, +# title="Sample Heatmap of Top Differentially Expressed Genes", +# save=False, +# output_dir="." +# ) diff --git a/networkcommons/_visual/styles.py b/networkcommons/_visual/styles.py new file mode 100644 index 0000000..4c21291 --- /dev/null +++ b/networkcommons/_visual/styles.py @@ -0,0 +1,151 @@ +from networkcommons._session import _log + +def get_styles(): + """ + Return a dictionary containing styles for different types of networks. 
+ """ + styles = { + 'default': { + 'nodes': { + 'sources': { + 'shape': 'circle', + 'color': 'steelblue', + 'style': 'filled', + 'fillcolor': 'steelblue', + 'label': '', + 'penwidth': 3 + }, + 'targets': { + 'shape': 'circle', + 'color': 'mediumpurple1', + 'style': 'filled', + 'fillcolor': 'mediumpurple1', + 'label': '', + 'penwidth': 3 + }, + 'other': { + 'shape': 'circle', + 'color': 'gray', + 'style': 'filled', + 'fillcolor': 'gray', + 'label': '' + } + }, + 'edges': { + 'neutral': { + 'color': 'gray30', + 'penwidth': 2 + } + } + }, + 'sign_consistent': { + 'nodes': { + 'sources': { + 'default': { + 'shape': 'circle', + 'style': 'filled', + 'fillcolor': 'steelblue', + 'label': '', + 'penwidth': 3, + 'color': 'steelblue' + }, + 'positive_consistent': { + 'color': 'forestgreen' + }, + 'negative_consistent': { + 'color': 'tomato3' + } + }, + 'targets': { + 'default': { + 'shape': 'circle', + 'style': 'filled', + 'fillcolor': 'mediumpurple1', + 'label': '', + 'penwidth': 3, + 'color': 'mediumpurple1' + }, + 'positive_consistent': { + 'color': 'forestgreen' + }, + 'negative_consistent': { + 'color': 'tomato3' + } + }, + 'other': { + 'default': { + 'shape': 'circle', + 'color': 'gray', + 'style': 'filled', + 'fillcolor': 'gray', + 'label': '' + } + } + }, + 'edges': { + 'positive': { + 'color': 'forestgreen', + 'penwidth': 2 + }, + 'negative': { + 'color': 'tomato3', + 'penwidth': 2 + }, + 'neutral': { + 'color': 'gray30', + 'penwidth': 2 + } + } + }, + # Add more network styles here + } + + return styles + + +def set_style_attributes(item, base_style, condition_style=None): + """ + Set attributes for a graph item (node or edge) based on the given styles. + + Args: + item (node or edge): The item to set attributes for. + base_style (dict): The base style dictionary with default attribute settings. + condition_style (dict, optional): A dictionary of attribute settings for specific conditions. Defaults to None. 
+ """ + for attr, value in base_style.items(): + item.attr[attr] = value + + if condition_style: + for attr, value in condition_style.items(): + item.attr[attr] = value + + return item + + +def merge_styles(default_style, custom_style, path=""): + """ + Merge custom styles with default styles to ensure all necessary fields are present. + + Args: + default_style (dict): The default style dictionary. + custom_style (dict): The custom style dictionary. + path (str): The path in the dictionary hierarchy for logging purposes. + + Returns: + dict: The merged style dictionary. + """ + merged_style = default_style.copy() + if custom_style is not None: + for key, value in custom_style.items(): + if isinstance(value, dict) and key in merged_style: + merged_style[key] = merge_styles(merged_style[key], value, f"{path}.{key}" if path else key) + else: + merged_style[key] = value + + # Log missing keys in custom_style + for key in default_style: + if key not in custom_style: + _log(f"Missing key '{path}.{key}' in custom style. Using default value.") + + return merged_style + diff --git a/networkcommons/_visual/vis_networkx.py b/networkcommons/_visual/vis_networkx.py new file mode 100644 index 0000000..f8b4a73 --- /dev/null +++ b/networkcommons/_visual/vis_networkx.py @@ -0,0 +1,308 @@ +#!/usr/bin/env python + +# +# This file is part of the `networkcommons` Python module +# +# Copyright 2024 +# Heidelberg University Hospital +# +# File author(s): Saez Lab (omnipathdb@gmail.com) +# +# Distributed under the GPLv3 license +# See the file `LICENSE` or read a copy at +# https://www.gnu.org/licenses/gpl-3.0.txt +# + +""" +This module contains functions to visualize networks. +The styles for different types of networks are defined in the get_styles() function. +The set_style_attributes() function is used to set attributes for nodes and edges based on the given styles. +The visualize_network_default() function visualizes the graph with default style. 
class NetworkXVisualizer:
    """
    Visualize a networkx graph, either with matplotlib or through pygraphviz.

    The instance keeps its own copy of the network passed to the constructor;
    ``color_nodes``, ``color_edges`` and ``visualise`` operate on that copy.
    The ``visualize_network*`` methods take the network to draw as an explicit
    argument and return a laid-out pygraphviz graph.
    """

    # NOTE(review): `get_styles`, `set_style_attributes` and `merge_styles` are
    # used below but are not imported in the visible part of this module;
    # confirm they are in scope at runtime.

    _default_edge_colors = {
        'stimulation': 'green',
        'inhibition': 'red',
        'form complex': 'blue'
    }

    _default_node_colors = {
        'initial_node': 'lightyellow',
        'noi': 'lightblue',
        'highlight': 'lightyellow',
        'default': 'lightgray'
    }

    def __init__(self, network, color_by="Effect", edge_colors=None):
        """
        Args:
            network: The graph to visualize; it is copied via ``network.copy()``.
            color_by (str, optional): Stored on the instance as ``color_by``.
                Defaults to "Effect".
            edge_colors (dict, optional): Edge colors to use instead of the
                class defaults.
        """
        self.network = network.copy()
        self.color_by = color_by

        # Always work on a private copy: sharing the class-level default dict
        # would let `set_custom_edge_colors` change the defaults of every
        # other instance.
        if edge_colors:
            self.edge_colors = dict(edge_colors)
        else:
            self.edge_colors = dict(self._default_edge_colors)

    def set_custom_edge_colors(self, custom_edge_colors):
        """Update this instance's edge colors with ``custom_edge_colors``."""
        self.edge_colors.update(custom_edge_colors)

    def color_nodes(self):
        """
        Store a fill color on every node of ``self.network``.

        Nodes whose ``initial_node`` attribute is truthy get the
        'initial_node' color, all others the 'default' color. The color is
        written to the ``fillcolor`` node attribute, which is what
        ``visualise`` reads.
        """
        nodes = self.network.nodes
        for node in nodes:
            attrs = nodes[node]
            if attrs.get("initial_node"):
                attrs["fillcolor"] = self._default_node_colors["initial_node"]
            else:
                attrs["fillcolor"] = self._default_node_colors["default"]

    def color_edges(self, edge, color):
        """
        Store ``color`` on one edge of ``self.network``.

        Args:
            edge (tuple): The edge key, e.g. ``(source, target)``.
            color (str): The color, written to the ``color`` edge attribute
                read by ``visualise``.
        """
        self.network.edges[edge]["color"] = color
        #TODO add arrowheads too

    @staticmethod
    def _flatten_targets(target_dict):
        """Collapse the inner dictionaries of ``target_dict`` into one flat dict."""
        return {
            sub_key: sub_value
            for value in target_dict.values()
            for sub_key, sub_value in value.items()
        }

    def visualize_network_default(self, network, source_dict, target_dict, prog='dot', custom_style=None):
        """
        Core function to visualize the graph.

        Args:
            network (nx.Graph): The network to visualize.
            source_dict (dict): A dictionary containing the sources and sign of perturbation.
            target_dict (dict): A dictionary containing the targets and sign of measurements.
                Its values must themselves be dictionaries keyed by node name.
            prog (str, optional): The layout program to use. Defaults to 'dot'.
            custom_style (dict, optional): The custom style to apply. If None, the default style is used.

        Returns:
            The pygraphviz graph built from ``network``, styled and laid out.
        """
        default_style = get_styles()['default']
        style = merge_styles(default_style, custom_style)

        A = nx.nx_agraph.to_agraph(network)
        A.graph_attr['ratio'] = '1.2'

        sources = set(source_dict.keys())
        targets = set(self._flatten_targets(target_dict).keys())

        for node in A.nodes():
            n = node.get_name()
            if n in sources:
                base_style = style['nodes']['sources']
            elif n in targets:
                base_style = style['nodes']['targets']
            else:
                base_style = style['nodes']['other']

            set_style_attributes(node, base_style)

        for edge in A.edges():
            edge_style = style['edges']['neutral']
            set_style_attributes(edge, edge_style)

        A.layout(prog=prog)
        return A

    def visualize_network_sign_consistent(self, network, source_dict, target_dict, prog='dot', custom_style=None):
        """
        Visualize the graph considering sign consistency.

        Args:
            network (nx.Graph): The network to visualize. It is not modified.
            source_dict (dict): A dictionary containing the sources and sign of perturbation.
            target_dict (dict): A dictionary containing the targets and sign of measurements.
            prog (str, optional): The layout program to use. Defaults to 'dot'.
            custom_style (dict, optional): The custom style to apply. Defaults to None.

        Returns:
            The pygraphviz graph with sign-dependent node and edge styles.
        """
        default_style = get_styles()['sign_consistent']
        style = merge_styles(default_style, custom_style)

        # Call the core visualization function
        A = self.visualize_network_default(network, source_dict, target_dict, prog, style)

        sources = set(source_dict.keys())
        target_dict_flat = self._flatten_targets(target_dict)
        targets = set(target_dict_flat.keys())

        for node in A.nodes():
            n = node.get_name()
            if n in sources:
                nodes_type = "sources"
            elif n in targets:
                nodes_type = "targets"
            else:
                # Neither source nor target: there is no sign to encode, the
                # base style applied by the core function stays as it is.
                continue

            # NOTE(review): the sign is looked up in the targets only, so a
            # source that is not also a target is always treated as positive;
            # confirm whether the sign in `source_dict` should be used here.
            sign_value = target_dict_flat.get(n, 1)

            condition_style = None
            if sign_value > 0:
                condition_style = style['nodes'][nodes_type].get('positive_consistent')
            elif sign_value < 0:
                condition_style = style['nodes'][nodes_type].get('negative_consistent')

            if condition_style:
                set_style_attributes(node, {}, condition_style)  # Apply condition style without overwriting base style

        for edge in A.edges():
            u, v = edge
            edge_data = network.get_edge_data(u, v) or {}
            # 'interaction' takes precedence over 'sign'. Read-only lookup so
            # the caller's graph keeps its attributes untouched.
            sign = edge_data.get('interaction', edge_data.get('sign'))

            if sign == 1:
                edge_style = style['edges']['positive']
            elif sign == -1:
                edge_style = style['edges']['negative']
            else:
                edge_style = style['edges']['neutral']

            set_style_attributes(edge, edge_style)

        return A

    def visualize_network(self,
                          network,
                          source_dict,
                          target_dict,
                          prog='dot',
                          network_type='default',
                          custom_style=None):
        """
        Main function to visualize the graph based on the network type.

        Args:
            network (nx.Graph): The network to visualize.
            source_dict (dict): A dictionary containing the sources and sign of perturbation.
            target_dict (dict): A dictionary containing the targets and sign of measurements.
            prog (str, optional): The layout program to use. Defaults to 'dot'.
            network_type (str, optional): The type of visualization to use. Defaults to "default".
                Any value other than 'sign_consistent' selects the default visualization.
            custom_style (dict, optional): The custom style to apply. Defaults to None.

        Returns:
            The pygraphviz graph produced by the selected visualization.
        """
        if network_type == 'sign_consistent':
            return self.visualize_network_sign_consistent(network, source_dict, target_dict, prog, custom_style)
        return self.visualize_network_default(network, source_dict, target_dict, prog, custom_style)

    def visualise(self, output_file='network.png', render=False, highlight_nodes=None, style=None):
        """
        Draw ``self.network`` with matplotlib using a spring layout.

        Args:
            output_file (str, optional): Where the figure is saved when
                ``render`` is false. Defaults to 'network.png'.
            render (bool, optional): Show the figure instead of saving it.
            highlight_nodes (list, optional): Node names to redraw in the
                highlight color; each name is passed through ``wrap_node_name``.
            style (dict, optional): May provide 'highlight_color'.
        """
        plt.figure(figsize=(12, 12))
        network = self.network
        pos = nx.spring_layout(network)

        node_colors = [network.nodes[node].get('fillcolor', 'lightgray') for node in network.nodes]
        edge_colors = [network.edges[edge].get('color', 'black') for edge in network.edges]

        nx.draw(network, pos, node_color=node_colors, edge_color=edge_colors, with_labels=True)

        if highlight_nodes:
            # `style` is optional; fall back to the class default highlight color.
            highlight_color = (style or {}).get('highlight_color') or self._default_node_colors['highlight']
            highlight_nodes = [wrap_node_name(node) for node in highlight_nodes]
            nx.draw_networkx_nodes(network, pos, nodelist=highlight_nodes, node_color=highlight_color)

        if render:
            plt.show()
        else:
            plt.savefig(output_file)
        # Release the figure in both cases so repeated calls do not pile up
        # open figures.
        plt.close()

    @staticmethod
    def visualize_big_graph():
        """Not implemented yet."""
        raise NotImplementedError

    @staticmethod
    def visualize_graph_split():
        """Not implemented yet."""
        raise NotImplementedError
class YFilesVisualizer:
    """
    Show a network in a yFiles ``GraphWidget``.

    The instance keeps a copy of the network passed to the constructor and the
    style dictionary returned by ``get_styles()``.
    """

    def __init__(self, network):
        """
        Args:
            network: The graph to visualize; it is copied via ``network.copy()``.
        """
        self.network = network.copy()
        self.styles = get_styles()

    def yfiles_visual(self, graph_layout, directed):
        """
        Display ``self.network`` with the default node and edge styles.

        Edges are colored by their ``effect`` attribute; edges without one get
        the default effect color.

        Args:
            graph_layout: Assigned to the widget's ``graph_layout``.
            directed (bool): Assigned to the widget's ``directed``.
        """
        # creating empty object for visualization
        w = GraphWidget()
        node_style = self.styles['default']['nodes']
        edge_style = self.styles['default']['edges']

        # filling w with nodes
        objects = []
        for node in self.network.nodes:
            # Start from the style defaults and set the label afterwards: the
            # defaults carry `label: ''`, which must not erase the real label.
            props = apply_node_style({}, node_style)
            props["label"] = node
            objects.append({
                "id": node,
                "properties": props,
                "color": props['color'],
                "styles": {"backgroundColor": props['fillcolor']}
            })
        w.nodes = objects

        # filling w with edges
        objects = []
        # `data=True` is required to reach the edge attributes; plain
        # iteration over `.edges` yields (source, target) pairs only.
        for idx, (source, target, data) in enumerate(self.network.edges(data=True)):
            props = apply_edge_style({}, edge_style)
            # Set after the defaults so the effect color is not overwritten.
            props["color"] = get_edge_color(data.get('effect'), self.styles)
            objects.append({
                "id": idx,
                "start": source,
                "end": target,
                "properties": props
            })
        w.edges = objects

        self._show(w, graph_layout, directed)

    def vis_comparison(self, int_comparison, node_comparison, graph_layout, directed):
        """
        Display the comparison of two networks.

        Args:
            int_comparison (pd.DataFrame): One row per edge, with the columns
                'source', 'target' and 'comparison'.
            node_comparison (pd.DataFrame): One row per node, with the columns
                'node' and 'comparison'.
            graph_layout: Assigned to the widget's ``graph_layout``.
            directed (bool): Assigned to the widget's ``directed``.
        """
        # creating empty object for visualization
        w = GraphWidget()
        node_style = self.styles['default']['nodes']
        edge_style = self.styles['default']['edges']

        objects = []
        for _, item in node_comparison.iterrows():
            props = apply_node_style({}, node_style)
            props["label"] = item["node"]
            props["comparison"] = item["comparison"]
            # NOTE(review): the original called `set_custom_node_color`, whose
            # import is commented out. Setting both the border and the fill
            # color is assumed to be the intent; confirm.
            color = get_comparison_color(item["comparison"], self.styles, 'nodes')
            props["color"] = color
            props["fillcolor"] = color
            objects.append({
                "id": item["node"],
                "properties": props,
                "color": props['color'],
                "styles": {"backgroundColor": props['fillcolor']}
            })
        w.nodes = objects

        # filling w with edges
        objects = []
        for idx, (_, row) in enumerate(int_comparison.iterrows()):
            props = apply_edge_style({}, edge_style)
            props["comparison"] = row["comparison"]
            props["color"] = get_comparison_color(row["comparison"], self.styles, 'edges')
            objects.append({
                # The comparison category repeats across rows, so it cannot
                # serve as an id; use the row position instead.
                "id": idx,
                "start": row["source"],
                "end": row["target"],
                "properties": props
            })
        w.edges = objects

        self._show(w, graph_layout, directed)

    def _show(self, widget, graph_layout, directed):
        """Register the style mappings, set layout and direction, then display."""
        widget.set_edge_color_mapping(self.custom_edge_color_mapping)
        widget.set_node_styles_mapping(self.custom_node_color_mapping)
        widget.set_node_scale_factor_mapping(self.custom_factor_mapping)
        widget.set_node_label_mapping(self.custom_label_styles_mapping)

        widget.directed = directed
        widget.graph_layout = graph_layout

        display(widget)

    @staticmethod
    def custom_edge_color_mapping(edge: Dict):
        """Return the color stored in the edge's properties."""
        return edge["properties"]["color"]

    @staticmethod
    def custom_node_color_mapping(node: Dict):
        """Return the node style mapping, carrying only the node's color."""
        return {"color": node["color"]}

    @staticmethod
    def custom_factor_mapping(node: Dict):
        """Return the constant scale factor used for every node."""
        return 5

    @staticmethod
    def custom_label_styles_mapping(node: Dict):
        """Return the default label style with the node's label as its text."""
        label_style = get_styles()['default']['labels'].copy()
        label_style['text'] = node["properties"]["label"]
        return label_style
def apply_node_style(node, style):
    """
    Copy every attribute of a style onto a node.

    Attributes already present on the node are overwritten by the style's
    values. The node is changed in place.

    Args:
        node (dict): The node to style.
        style (dict): The style dictionary with node attributes.

    Returns:
        dict: The same ``node`` object, after styling.
    """
    for attr in style:
        node[attr] = style[attr]
    return node


def apply_edge_style(edge, style):
    """
    Copy every attribute of a style onto an edge.

    Attributes already present on the edge are overwritten by the style's
    values. The edge is changed in place.

    Args:
        edge (dict): The edge to style.
        style (dict): The style dictionary with edge attributes.

    Returns:
        dict: The same ``edge`` object, after styling.
    """
    for attr in style:
        edge[attr] = style[attr]
    return edge
def get_comparison_color(category, styles, element_type='nodes'):
    """
    Look up the color of a comparison category for nodes or edges.

    Categories that have no entry in the comparison palette get the default
    color of the given element type.

    Args:
        category (str): The comparison category.
        styles (dict): The styles dictionary.
        element_type (str): The type of element ('nodes' or 'edges').

    Returns:
        str: The color for the element.
    """
    palette = styles['comparison'][element_type]
    fallback = styles['default'][element_type]['color']
    return palette.get(category, fallback)