From 09ba3655012f12576c0d45cf7eafb8ed773755b4 Mon Sep 17 00:00:00 2001 From: Olga Ivanova Date: Mon, 1 Jul 2024 12:36:11 +0200 Subject: [PATCH 1/4] added the vis module from Neco, integrated/refactored with networkcommons code; more structure added to the vis --- networkcommons/_visual/__init__.py | 1 + networkcommons/_visual/_aux.py | 11 + networkcommons/_visual/_network_mock.py | 91 +++++ networkcommons/_visual/networkx.py | 445 ++++++++++-------------- networkcommons/_visual/styles.py | 146 ++++++++ networkcommons/_visual/yfiles.py | 187 ++++++++++ 6 files changed, 621 insertions(+), 260 deletions(-) create mode 100644 networkcommons/_visual/_aux.py create mode 100644 networkcommons/_visual/_network_mock.py create mode 100644 networkcommons/_visual/styles.py create mode 100644 networkcommons/_visual/yfiles.py diff --git a/networkcommons/_visual/__init__.py b/networkcommons/_visual/__init__.py index 5fa9849..36120f4 100644 --- a/networkcommons/_visual/__init__.py +++ b/networkcommons/_visual/__init__.py @@ -18,3 +18,4 @@ """ from . import networkx as networkx +from . import yfiles as yfiles diff --git a/networkcommons/_visual/_aux.py b/networkcommons/_visual/_aux.py new file mode 100644 index 0000000..b15d4bf --- /dev/null +++ b/networkcommons/_visual/_aux.py @@ -0,0 +1,11 @@ + +def wrap_node_name(node_name): + if ":" in node_name: + node_name = node_name.replace(":", "_") + if node_name.startswith("COMPLEX"): + # remove the word COMPLEX with a separator (:/-, etc) + return node_name[8:] + else: + return node_name + + diff --git a/networkcommons/_visual/_network_mock.py b/networkcommons/_visual/_network_mock.py new file mode 100644 index 0000000..6e2090d --- /dev/null +++ b/networkcommons/_visual/_network_mock.py @@ -0,0 +1,91 @@ +from graphviz import Digraph +from IPython.display import display +from yfiles_jupyter_graphs import GraphWidget +from typing import Dict, List +import networkx as nx +import pandas as pd +import matplotlib.pyplot as plt +import datetime +import os +from pypath.utils import mapping + +from _aux import wrap_node_name +import yfiles + +class NetworkMock: + + def __init__(self): + self.network = nx.DiGraph() + self.color_map = {} + + def init_from_sif(self, filepath): + """ + Initialize the network from a SIF file + """ + with open(filepath, 'r') as f: + for line in f: + source, target = line.strip().split() + self.network.add_edge(source, target) + return self.network + + def add_nodes(self, nodes): + # add nodes to networkx + self.network.add_nodes_from(nodes) + return self.network + + def set_initial_nodes(self, initial_nodes): + self.update_node_property(initial_nodes, type="initial_node") + + def set_nodes_of_interest(self, nodes): + self.update_node_property(nodes, type="noi", value="1") + + def add_edges(self, edges): + self.network.add_edges_from(edges) + + def update_node_property(self, node, type="color", value="blue"): + # add color to the node in networkx + if type == "color": + self.color_map[node] = value + elif type == "initial_node": + self.network.nodes[node]["initial_node"] = True + return self.color_map + + def update_edge_property(self, edge, type="color", color="blue"): + # add color to the edge in networkx + if type == "color": + self.color_map[edge] = color + return self.color_map + + def draw(self, filepath=None, render=False): + networkx_vis = yfiles.NetworkXVisualizer() + networkx_vis.visualise(render=render) + + def mapping_node_identifier(self, node: str) -> List[str]: + complex_string = None + gene_symbol = None + uniprot = None + + if mapping.id_from_label0(node): + uniprot = mapping.id_from_label0(node) + gene_symbol = mapping.label(uniprot) + elif node.startswith("COMPLEX"): + node = node[8:] #TODO change to wrap + node_list = node.split("_") + translated_node_list = [mapping.label(mapping.id_from_label0(item)) for item in node_list] + complex_string = "COMPLEX:" + "_".join(translated_node_list) + elif mapping.label(node): + gene_symbol = mapping.label(node) + uniprot = mapping.id_from_label0(gene_symbol) + else: + print("Error during translation, check syntax for ", node) + + return [complex_string, gene_symbol, uniprot] + + def convert_edges_into_genesymbol(self, edges: pd.DataFrame) -> pd.DataFrame: + def convert_identifier(x): + identifiers = self.mapping_node_identifier(x) + return identifiers[1] # Using GeneSymbol + edges["source"] = edges["source"].apply(convert_identifier) + edges["target"] = edges["target"].apply(convert_identifier) + return edges + diff --git a/networkcommons/_visual/networkx.py b/networkcommons/_visual/networkx.py index e397d46..f8b4a73 100644 --- a/networkcommons/_visual/networkx.py +++ b/networkcommons/_visual/networkx.py @@ -23,291 +23,216 @@ """ import networkx as nx +import matplotlib.pyplot as plt +from _aux import wrap_node_name from networkcommons._session import _log -__all__ = [ - 'get_styles', - 'set_style_attributes', - 'merge_styles', - 'visualize_network', - 'visualize_network_default', - 'visualize_network_sign_consistent', - 'visualize_big_graph', - 'visualize_graph_split', -] - - -def get_styles(): - """ - Return a dictionary containing styles for different types of networks. - """ - styles = { - 'default': { - 'nodes': { - 'sources': { - 'shape': 'circle', - 'color': 'steelblue', - 'style': 'filled', - 'fillcolor': 'steelblue', - 'label': '', - 'penwidth': 3 - }, - 'targets': { - 'shape': 'circle', - 'color': 'mediumpurple1', - 'style': 'filled', - 'fillcolor': 'mediumpurple1', - 'label': '', - 'penwidth': 3 - }, - 'other': { - 'shape': 'circle', - 'color': 'gray', - 'style': 'filled', - 'fillcolor': 'gray', - 'label': '' - } - }, - 'edges': { - 'neutral': { - 'color': 'gray30', - 'penwidth': 2 - } - } - }, - 'sign_consistent': { - 'nodes': { - 'sources': { - 'default': { - 'shape': 'circle', - 'style': 'filled', - 'fillcolor': 'steelblue', - 'label': '', - 'penwidth': 3, - 'color': 'steelblue' - }, - 'positive_consistent': { - 'color': 'forestgreen' - }, - 'negative_consistent': { - 'color': 'tomato3' - } - }, - 'targets': { - 'default': { - 'shape': 'circle', - 'style': 'filled', - 'fillcolor': 'mediumpurple1', - 'label': '', - 'penwidth': 3, - 'color': 'mediumpurple1' - }, - 'positive_consistent': { - 'color': 'forestgreen' - }, - 'negative_consistent': { - 'color': 'tomato3' - } - }, - 'other': { - 'default': { - 'shape': 'circle', - 'color': 'gray', - 'style': 'filled', - 'fillcolor': 'gray', - 'label': '' - } - } - }, - 'edges': { - 'positive': { - 'color': 'forestgreen', - 'penwidth': 2 - }, - 'negative': { - 'color': 'tomato3', - 'penwidth': 2 - }, - 'neutral': { - 'color': 'gray30', - 'penwidth': 2 - } - } - }, - # Add more network styles here +#TODO +# __all__ = [ +# 'get_styles', +# 'set_style_attributes', +# 'merge_styles', +# 'visualize_network', +# 'visualize_network_default', +# 'visualize_network_sign_consistent', +# 'visualize_big_graph', +# 'visualize_graph_split', +# ] + + +class NetworkXVisualizer: + + _default_edge_colors = { + 'stimulation': 'green', + 'inhibition': 'red', + 'form complex': 'blue' } - return styles - - -def set_style_attributes(item, base_style, condition_style=None): - """ - Set attributes for a graph item (node or edge) based on the given styles. - - Args: - item (node or edge): The item to set attributes for. - base_style (dict): The base style dictionary with default attribute settings. - condition_style (dict, optional): A dictionary of attribute settings for specific conditions. Defaults to None. - """ - for attr, value in base_style.items(): - item.attr[attr] = value - - if condition_style: - for attr, value in condition_style.items(): - item.attr[attr] = value - - -def merge_styles(default_style, custom_style, path=""): - """ - Merge custom styles with default styles to ensure all necessary fields are present. + _default_node_colors = { + 'initial_node': 'lightyellow', + 'noi': 'lightblue', + 'highlight': 'lightyellow', + 'default': 'lightgray' + } - Args: - default_style (dict): The default style dictionary. - custom_style (dict): The custom style dictionary. - path (str): The path in the dictionary hierarchy for logging purposes. + def __init__(self, network, color_by="Effect", edge_colors=None): + self.network = network.copy() + self.color_by = color_by - Returns: - dict: The merged style dictionary. - """ - merged_style = default_style.copy() - if custom_style is not None: - for key, value in custom_style.items(): - if isinstance(value, dict) and key in merged_style: - merged_style[key] = merge_styles(merged_style[key], value, f"{path}.{key}" if path else key) - else: - merged_style[key] = value - - # Log missing keys in custom_style - for key in default_style: - if key not in custom_style: - _log(f"Missing key '{path}.{key}' in custom style. Using default value.") - - return merged_style - - -def visualize_network_default(network, source_dict, target_dict, prog='dot', custom_style=None): - """ - Core function to visualize the graph. - - Args: - network (nx.Graph): The network to visualize. - source_dict (dict): A dictionary containing the sources and sign of perturbation. - target_dict (dict): A dictionary containing the targets and sign of measurements. - prog (str, optional): The layout program to use. Defaults to 'dot'. - custom_style (dict, optional): The custom style to apply. If None, the default style is used. - """ - default_style = get_styles()['default'] - style = merge_styles(default_style, custom_style) - - A = nx.nx_agraph.to_agraph(network) - A.graph_attr['ratio'] = '1.2' - - sources = set(source_dict.keys()) - target_dict_flat = {sub_key: sub_value for key, value in target_dict.items() for sub_key, sub_value in value.items()} - targets = set(target_dict_flat.keys()) - - for node in A.nodes(): - n = node.get_name() - if n in sources: - base_style = style['nodes']['sources'] - elif n in targets: - base_style = style['nodes']['targets'] + if edge_colors: + self.edge_colors = edge_colors else: - base_style = style['nodes']['other'] - - set_style_attributes(node, base_style) - - for edge in A.edges(): - edge_style = style['edges']['neutral'] - set_style_attributes(edge, edge_style) - - A.layout(prog=prog) - return A - - -def visualize_network_sign_consistent(network, source_dict, target_dict, prog='dot', custom_style=None): - """ - Visualize the graph considering sign consistency. - - Args: - network (nx.Graph): The network to visualize. - source_dict (dict): A dictionary containing the sources and sign of perturbation. - target_dict (dict): A dictionary containing the targets and sign of measurements. - prog (str, optional): The layout program to use. Defaults to 'dot'. - custom_style (dict, optional): The custom style to apply. Defaults to None. - """ - default_style = get_styles()['sign_consistent'] - style = merge_styles(default_style, custom_style) - - # Call the core visualization function - A = visualize_network_default(network, source_dict, target_dict, prog, style) - - sources = set(source_dict.keys()) - target_dict_flat = {sub_key: sub_value for key, value in target_dict.items() for sub_key, sub_value in value.items()} - targets = set(target_dict_flat.keys()) - - for node in A.nodes(): - n = node.get_name() - condition_style = None - sign_value = target_dict_flat.get(n, 1) + self.edge_colors = self._default_edge_colors + + def set_custom_edge_colors(self, custom_edge_colors): + self.edge_colors.update(custom_edge_colors) + + def color_nodes(self): + nodes = self.network.nodes + for node in nodes: + nodes.update_node_property(node, type="color", value="lightgray") + + if nodes[node].get("initial_node"): + self.network.update_node_property(node, + type="color", + value="lightyellow") + def color_edges(self, edge, color): + self.network.update_edge_property(edge, type="color", color=color) + #TODO add arrowheads too + + def visualize_network_default(self, network, source_dict, target_dict, prog='dot', custom_style=None): + """ + Core function to visualize the graph. + + Args: + network (nx.Graph): The network to visualize. + source_dict (dict): A dictionary containing the sources and sign of perturbation. + target_dict (dict): A dictionary containing the targets and sign of measurements. + prog (str, optional): The layout program to use. Defaults to 'dot'. + custom_style (dict, optional): The custom style to apply. If None, the default style is used. + """ + default_style = get_styles()['default'] + style = merge_styles(default_style, custom_style) + + A = nx.nx_agraph.to_agraph(network) + A.graph_attr['ratio'] = '1.2' + + sources = set(source_dict.keys()) + target_dict_flat = {sub_key: sub_value for key, value in target_dict.items() for sub_key, sub_value in + value.items()} + targets = set(target_dict_flat.keys()) + + for node in A.nodes(): + n = node.get_name() + if n in sources: + base_style = style['nodes']['sources'] + elif n in targets: + base_style = style['nodes']['targets'] + else: + base_style = style['nodes']['other'] - if n in sources: - nodes_type = "sources" - elif n in targets: - nodes_type = "targets" + set_style_attributes(node, base_style) - if sign_value > 0: + for edge in A.edges(): + edge_style = style['edges']['neutral'] + set_style_attributes(edge, edge_style) + + A.layout(prog=prog) + return A + + def visualize_network_sign_consistent(network, source_dict, target_dict, prog='dot', custom_style=None): + """ + Visualize the graph considering sign consistency. + + Args: + network (nx.Graph): The network to visualize. + source_dict (dict): A dictionary containing the sources and sign of perturbation. + target_dict (dict): A dictionary containing the targets and sign of measurements. + prog (str, optional): The layout program to use. Defaults to 'dot'. + custom_style (dict, optional): The custom style to apply. Defaults to None. + """ + default_style = get_styles()['sign_consistent'] + style = merge_styles(default_style, custom_style) + + # Call the core visualization function + A = visualize_network_default(network, source_dict, target_dict, prog, style) + + sources = set(source_dict.keys()) + target_dict_flat = {sub_key: sub_value for key, value in target_dict.items() for sub_key, sub_value in + value.items()} + targets = set(target_dict_flat.keys()) + + for node in A.nodes(): + n = node.get_name() + condition_style = None + sign_value = target_dict_flat.get(n, 1) + + if n in sources: + nodes_type = "sources" + elif n in targets: + nodes_type = "targets" + + if sign_value > 0: condition_style = style['nodes'][nodes_type].get('positive_consistent') - elif sign_value < 0: + elif sign_value < 0: condition_style = style['nodes'][nodes_type].get('negative_consistent') - if condition_style: - set_style_attributes(node, {}, condition_style) # Apply condition style without overwriting base style + if condition_style: + set_style_attributes(node, {}, condition_style) # Apply condition style without overwriting base style - for edge in A.edges(): - u, v = edge - edge_data = network.get_edge_data(u, v) - if 'interaction' in edge_data: - edge_data['sign'] = edge_data.pop('interaction') + for edge in A.edges(): + u, v = edge + edge_data = network.get_edge_data(u, v) + if 'interaction' in edge_data: + edge_data['sign'] = edge_data.pop('interaction') - if edge_data['sign'] == 1: - edge_style = style['edges']['positive'] - elif edge_data['sign'] == -1: - edge_style = style['edges']['negative'] + if edge_data['sign'] == 1: + edge_style = style['edges']['positive'] + elif edge_data['sign'] == -1: + edge_style = style['edges']['negative'] + else: + edge_style = style['edges']['neutral'] + + set_style_attributes(edge, edge_style) + + return A + + def visualize_network(self, + network, + source_dict, + target_dict, + prog='dot', + network_type='default', + custom_style=None): + """ + Main function to visualize the graph based on the network type. + + Args: + network (nx.Graph): The network to visualize. + source_dict (dict): A dictionary containing the sources and sign of perturbation. + target_dict (dict): A dictionary containing the targets and sign of measurements. + prog (str, optional): The layout program to use. Defaults to 'dot'. + network_type (str, optional): The type of visualization to use. Defaults to "default". + custom_style (dict, optional): The custom style to apply. Defaults to None. + """ + if network_type == 'sign_consistent': + return visualize_network_sign_consistent(network, source_dict, target_dict, prog, custom_style) else: - edge_style = style['edges']['neutral'] + default_style = get_styles().get(network_type, get_styles()['default']) + return visualize_network_default(network, source_dict, target_dict, prog, custom_style) - set_style_attributes(edge, edge_style) - return A + def visualise(self, output_file='network.png', render=False, highlight_nodes=None, style=None): + plt.figure(figsize=(12, 12)) + network = self.network + pos = nx.spring_layout(network) + node_colors = [network.nodes[node].get('fillcolor', 'lightgray') for node in network.nodes] + edge_colors = [network.edges[edge].get('color', 'black') for edge in network.edges] -def visualize_network(network, source_dict, target_dict, prog='dot', network_type='default', custom_style=None): - """ - Main function to visualize the graph based on the network type. + nx.draw(network, pos, node_color=node_colors, edge_color=edge_colors, with_labels=True) - Args: - network (nx.Graph): The network to visualize. - source_dict (dict): A dictionary containing the sources and sign of perturbation. - target_dict (dict): A dictionary containing the targets and sign of measurements. - prog (str, optional): The layout program to use. Defaults to 'dot'. - network_type (str, optional): The type of visualization to use. Defaults to "default". - custom_style (dict, optional): The custom style to apply. Defaults to None. - """ - if network_type == 'sign_consistent': - return visualize_network_sign_consistent(network, source_dict, target_dict, prog, custom_style) - else: - default_style = get_styles().get(network_type, get_styles()['default']) - return visualize_network_default(network, source_dict, target_dict, prog, custom_style) + if highlight_nodes: + if style.get('highlight_color'): + highlight_color = style['highlight_color'] + else: + highlight_color = self._default_node_colors['highlight'] + highlight_nodes = [wrap_node_name(node) for node in highlight_nodes] + nx.draw_networkx_nodes(self.network, pos, nodelist=highlight_nodes, node_color=highlight_color) + if render: + plt.show() + else: + plt.savefig(output_file) + plt.close() -def visualize_big_graph(): - return NotImplementedError + def visualize_big_graph(): + return NotImplementedError + def visualize_graph_split(): + return NotImplementedError -def visualize_graph_split(): - return NotImplementedError #----------------------------- diff --git a/networkcommons/_visual/styles.py b/networkcommons/_visual/styles.py new file mode 100644 index 0000000..9b08ae0 --- /dev/null +++ b/networkcommons/_visual/styles.py @@ -0,0 +1,146 @@ +def get_styles(): + """ + Return a dictionary containing styles for different types of networks. + """ + styles = { + 'default': { + 'nodes': { + 'sources': { + 'shape': 'circle', + 'color': 'steelblue', + 'style': 'filled', + 'fillcolor': 'steelblue', + 'label': '', + 'penwidth': 3 + }, + 'targets': { + 'shape': 'circle', + 'color': 'mediumpurple1', + 'style': 'filled', + 'fillcolor': 'mediumpurple1', + 'label': '', + 'penwidth': 3 + }, + 'other': { + 'shape': 'circle', + 'color': 'gray', + 'style': 'filled', + 'fillcolor': 'gray', + 'label': '' + } + }, + 'edges': { + 'neutral': { + 'color': 'gray30', + 'penwidth': 2 + } + } + }, + 'sign_consistent': { + 'nodes': { + 'sources': { + 'default': { + 'shape': 'circle', + 'style': 'filled', + 'fillcolor': 'steelblue', + 'label': '', + 'penwidth': 3, + 'color': 'steelblue' + }, + 'positive_consistent': { + 'color': 'forestgreen' + }, + 'negative_consistent': { + 'color': 'tomato3' + } + }, + 'targets': { + 'default': { + 'shape': 'circle', + 'style': 'filled', + 'fillcolor': 'mediumpurple1', + 'label': '', + 'penwidth': 3, + 'color': 'mediumpurple1' + }, + 'positive_consistent': { + 'color': 'forestgreen' + }, + 'negative_consistent': { + 'color': 'tomato3' + } + }, + 'other': { + 'default': { + 'shape': 'circle', + 'color': 'gray', + 'style': 'filled', + 'fillcolor': 'gray', + 'label': '' + } + } + }, + 'edges': { + 'positive': { + 'color': 'forestgreen', + 'penwidth': 2 + }, + 'negative': { + 'color': 'tomato3', + 'penwidth': 2 + }, + 'neutral': { + 'color': 'gray30', + 'penwidth': 2 + } + } + }, + # Add more network styles here + } + + return styles + +def set_style_attributes(item, base_style, condition_style=None): + """ + Set attributes for a graph item (node or edge) based on the given styles. + + Args: + item (node or edge): The item to set attributes for. + base_style (dict): The base style dictionary with default attribute settings. + condition_style (dict, optional): A dictionary of attribute settings for specific conditions. Defaults to None. + """ + for attr, value in base_style.items(): + item.attr[attr] = value + + if condition_style: + for attr, value in condition_style.items(): + item.attr[attr] = value + + +def merge_styles(default_style, custom_style, path=""): + """ + Merge custom styles with default styles to ensure all necessary fields are present. + + Args: + default_style (dict): The default style dictionary. + custom_style (dict): The custom style dictionary. + path (str): The path in the dictionary hierarchy for logging purposes. + + Returns: + dict: The merged style dictionary. + """ + merged_style = default_style.copy() + if custom_style is not None: + for key, value in custom_style.items(): + if isinstance(value, dict) and key in merged_style: + merged_style[key] = merge_styles(merged_style[key], value, f"{path}.{key}" if path else key) + else: + merged_style[key] = value + + # Log missing keys in custom_style + for key in default_style: + if key not in custom_style: + _log(f"Missing key '{path}.{key}' in custom style. Using default value.") + + return merged_style + diff --git a/networkcommons/_visual/yfiles.py b/networkcommons/_visual/yfiles.py new file mode 100644 index 0000000..6cb9fc2 --- /dev/null +++ b/networkcommons/_visual/yfiles.py @@ -0,0 +1,187 @@ +import pandas as pd +import networkx as nx +import matplotlib.pyplot as plt +from yfiles_jupyter_graphs import GraphWidget +from IPython.display import display +from typing import Dict, List +from pypath.utils import mapping +from _aux import wrap_node_name + + + + +class YFilesVisualizer: + + def __init__(self, network): + self.network = network.copy() + + def yfiles_visual( + self, + graph_layout, + directed, + ): + # creating empty object for visualization + w = GraphWidget() + + # filling w with nodes + objects = [] + for idx, item in self.dataframe_nodes.iterrows(): + obj = { + "id": self.dataframe_nodes["Uniprot"].loc[idx], + "properties": {"label": self.dataframe_nodes["Genesymbol"].loc[idx]}, + "color": "#ffffff", + "styles": {"backgroundColor": "#ffffff"} + } + objects.append(obj) + w.nodes = objects + + # filling w with edges + objects = [] + for index, row in self.dataframe_edges.iterrows(): + obj = { + "id": self.dataframe_edges["Effect"].loc[index], + "start": self.dataframe_edges["source"].loc[index], + "end": self.dataframe_edges["target"].loc[index], + "properties": {"references": self.dataframe_edges["References"].loc[index]}} + objects.append(obj) + w.edges = objects + + def custom_edge_color_mapping(edge: Dict): + """let the edge be red if the interaction is an inhibition, else green""" + return ("#fa1505" if edge['id'] == "inhibition" else "#05e60c") + + w.set_edge_color_mapping(custom_edge_color_mapping) + + def custom_node_color_mapping(node: Dict): + return {"color": "#ffffff"} + + w.set_node_styles_mapping(custom_node_color_mapping) + + def custom_factor_mapping(node: Dict): + """choose random factor""" + return 5 + + w.set_node_scale_factor_mapping(custom_factor_mapping) + + def custom_label_styles_mapping(node: Dict): + """let the label be the negated purple big index""" + return { + 'text': node["properties"]["label"], + 'backgroundColor': None, + 'fontSize': 40, + 'color': '#030200', + 'shape': 'round-rectangle', + 'textAlignment': 'center' + } + + w.set_node_label_mapping(custom_label_styles_mapping) + + w.directed = directed + w.graph_layout = graph_layout + + display(w) + + def vis_comparison( + self, + int_comparison, + node_comparison, + graph_layout, + directed, + ): + # creating empty object for visualization + w = GraphWidget() + + objects = [] + for idx, item in node_comparison.iterrows(): + obj = { + "id": node_comparison["node"].loc[idx], + "properties": {"label": node_comparison["node"].loc[idx], + "comparison": node_comparison["comparison"].loc[idx], }, + "color": "#ffffff", + # "styles":{"backgroundColor":"#ffffff"} + } + objects.append(obj) + w.nodes = objects + + # filling w with edges + objects = [] + for index, row in int_comparison.iterrows(): + obj = { + "id": int_comparison["comparison"].loc[index], + "properties": { + "comparison": int_comparison["comparison"].loc[index]}, + "start": int_comparison["source"].loc[index], + "end": int_comparison["target"].loc[index] + } + objects.append(obj) + w.edges = objects + + def custom_node_color_mapping(node: Dict): + if node['properties']['comparison'] == "Unique to Network 1": + return {"color": "#f5f536"} + elif node['properties']['comparison'] == "Unique to Network 2": + return {"color": "#36f55f"} + elif node['properties']['comparison'] == "Common": + return {"color": "#3643f5"} + + w.set_node_styles_mapping(custom_node_color_mapping) + + def custom_factor_mapping(node: Dict): + """choose random factor""" + return 5 + + w.set_node_scale_factor_mapping(custom_factor_mapping) + + def custom_label_styles_mapping(node: Dict): + """let the label be the negated purple big index""" + return { + 'text': node["id"], + 'backgroundColor': None, + 'fontSize': 20, + 'color': '#030200', + 'position': 'center', + 'maximumWidth': 130, + 'wrapping': 'word', + 'textAlignment': 'center' + } + + w.set_node_label_mapping(custom_label_styles_mapping) + + def custom_edge_color_mapping(edge: Dict): + if edge['id'] == "Unique to Network 1": + return "#e3941e" + elif edge['id'] == "Unique to Network 2": + return "#36f55f" + elif edge['id'] == "Common": + return "#3643f5" + elif edge['id'] == "Conflicting": + return "#ffcc00" + + w.set_edge_color_mapping(custom_edge_color_mapping) + + w.directed = directed + w.graph_layout = graph_layout + + display(w) + + +# Example usage +# network is assumed to be a custom object with edges, nodes, and initial_nodes attributes. +# Here is a mock example of how to create such an object. You should replace it with your actual network data. + +# class MockNetwork: +# def __init__(self): +# self.edges = pd.DataFrame({ +# 'source': ['A', 'B', 'C'], +# 'target': ['B', 'C', 'D'], +# 'Effect': ['stimulation', 'inhibition', 'form complex'] +# }) +# self.nodes = pd.DataFrame({ +# 'Genesymbol': ['A', 'B', 'C', 'D'], +# 'Uniprot': ['P1', 'P2', 'P3', 'P4'] +# }) +# self.initial_nodes = ['A', 'B'] +# +# network = MockNetwork() +# visualizer = NetworkXVisualizer(network) +# visualizer.render(view=True) From 3b35adc94d7195e7c1da0e19cc83c31a89079834 Mon Sep 17 00:00:00 2001 From: Olga Ivanova Date: Tue, 2 Jul 2024 10:53:40 +0200 Subject: [PATCH 2/4] styles updated: separating styles for networkx and yfiles, yfiles can be set through a dict --- networkcommons/_visual/styles.py | 5 + networkcommons/_visual/yfiles.py | 203 +++++++++--------------- networkcommons/_visual/yfiles_styles.py | 133 ++++++++++++++++ 3 files changed, 216 insertions(+), 125 deletions(-) create mode 100644 networkcommons/_visual/yfiles_styles.py diff --git a/networkcommons/_visual/styles.py b/networkcommons/_visual/styles.py index 9b08ae0..4c21291 100644 --- a/networkcommons/_visual/styles.py +++ b/networkcommons/_visual/styles.py @@ -1,3 +1,5 @@ +from networkcommons._session import _log + def get_styles(): """ Return a dictionary containing styles for different types of networks. @@ -100,6 +102,7 @@ def get_styles(): return styles + def set_style_attributes(item, base_style, condition_style=None): """ Set attributes for a graph item (node or edge) based on the given styles. @@ -116,6 +119,8 @@ def set_style_attributes(item, base_style, condition_style=None): for attr, value in condition_style.items(): item.attr[attr] = value + return item + def merge_styles(default_style, custom_style, path=""): """ diff --git a/networkcommons/_visual/yfiles.py b/networkcommons/_visual/yfiles.py index 6cb9fc2..65f5c57 100644 --- a/networkcommons/_visual/yfiles.py +++ b/networkcommons/_visual/yfiles.py @@ -1,174 +1,127 @@ -import pandas as pd -import networkx as nx -import matplotlib.pyplot as plt from yfiles_jupyter_graphs import GraphWidget +from typing import Dict from IPython.display import display -from typing import Dict, List -from pypath.utils import mapping from _aux import wrap_node_name +import pandas as pd - - +from yfiles_styles import (get_styles, + apply_node_style, + apply_edge_style, + #set_custom_node_color, + #set_custom_edge_color, + get_edge_color, get_comparison_color) class YFilesVisualizer: def __init__(self, network): self.network = network.copy() + self.styles = get_styles() - def yfiles_visual( - self, - graph_layout, - directed, - ): + def yfiles_visual(self, graph_layout, directed): # creating empty object for visualization w = GraphWidget() # filling w with nodes objects = [] - for idx, item in self.dataframe_nodes.iterrows(): - obj = { - "id": self.dataframe_nodes["Uniprot"].loc[idx], - "properties": {"label": self.dataframe_nodes["Genesymbol"].loc[idx]}, - "color": "#ffffff", - "styles": {"backgroundColor": "#ffffff"} - } - objects.append(obj) + + for node in self.network.nodes: + node_props = {"label": node} + node = apply_node_style(node_props, self.styles['default']['nodes']) + objects.append({ + "id": node, + "properties": node, + "color": node['color'], + "styles": {"backgroundColor": node['fillcolor']} + }) + w.nodes = objects # filling w with edges objects = [] - for index, row in self.dataframe_edges.iterrows(): - obj = { - "id": self.dataframe_edges["Effect"].loc[index], - "start": self.dataframe_edges["source"].loc[index], - "end": self.dataframe_edges["target"].loc[index], - "properties": {"references": self.dataframe_edges["References"].loc[index]}} - objects.append(obj) - w.edges = objects - - def custom_edge_color_mapping(edge: Dict): - """let the edge be red if the interaction is an inhibition, else green""" - return ("#fa1505" if edge['id'] == "inhibition" else "#05e60c") - - w.set_edge_color_mapping(custom_edge_color_mapping) - def custom_node_color_mapping(node: Dict): - return {"color": "#ffffff"} + for edge in self.network.edges: + edge_props = {"color": get_edge_color(edge[2]['effect'], self.styles)} + edge = apply_edge_style(edge_props, self.styles['default']['edges']) + objects.append({ + "id": edge, + "start": edge[0], + "end": edge[1], + "properties": edge + }) - w.set_node_styles_mapping(custom_node_color_mapping) - - def custom_factor_mapping(node: Dict): - """choose random factor""" - return 5 - - w.set_node_scale_factor_mapping(custom_factor_mapping) - - def custom_label_styles_mapping(node: Dict): - """let the label be the negated purple big index""" - return { - 'text': node["properties"]["label"], - 'backgroundColor': None, - 'fontSize': 40, - 'color': '#030200', - 'shape': 'round-rectangle', - 'textAlignment': 'center' - } + w.edges = objects - w.set_node_label_mapping(custom_label_styles_mapping) + w.set_edge_color_mapping(self.custom_edge_color_mapping) + w.set_node_styles_mapping(self.custom_node_color_mapping) + w.set_node_scale_factor_mapping(self.custom_factor_mapping) + w.set_node_label_mapping(self.custom_label_styles_mapping) w.directed = directed w.graph_layout = graph_layout display(w) - def vis_comparison( - self, - int_comparison, - node_comparison, - graph_layout, - directed, - ): + def vis_comparison(self, int_comparison, node_comparison, graph_layout, directed): # creating empty object for visualization w = GraphWidget() objects = [] for idx, item in node_comparison.iterrows(): - obj = { - "id": node_comparison["node"].loc[idx], - "properties": {"label": node_comparison["node"].loc[idx], - "comparison": node_comparison["comparison"].loc[idx], }, - "color": "#ffffff", - # "styles":{"backgroundColor":"#ffffff"} - } - objects.append(obj) + node_props = {"label": item["node"], "comparison": item["comparison"]} + node = apply_node_style(node_props, self.styles['default']['nodes']) + node = set_custom_node_color(node, get_comparison_color(item["comparison"], self.styles, 'nodes')) + objects.append({ + "id": item["node"], + "properties": node, + "color": node['color'], + "styles": {"backgroundColor": node['fillcolor']} + }) w.nodes = objects # filling w with edges objects = [] for index, row in int_comparison.iterrows(): - obj = { - "id": int_comparison["comparison"].loc[index], - "properties": { - "comparison": int_comparison["comparison"].loc[index]}, - "start": int_comparison["source"].loc[index], - "end": int_comparison["target"].loc[index] - } - objects.append(obj) + edge_props = {"comparison": row["comparison"]} + edge = apply_edge_style(edge_props, self.styles['default']['edges']) + edge = set_custom_edge_color(edge, get_comparison_color(row["comparison"], self.styles, 'edges')) + objects.append({ + "id": row["comparison"], + "start": row["source"], + "end": row["target"], + "properties": edge + }) w.edges = objects - def custom_node_color_mapping(node: Dict): - if node['properties']['comparison'] == "Unique to Network 1": - return {"color": "#f5f536"} - elif node['properties']['comparison'] == "Unique to Network 2": - return {"color": "#36f55f"} - elif node['properties']['comparison'] == "Common": - return {"color": "#3643f5"} - - w.set_node_styles_mapping(custom_node_color_mapping) - - def custom_factor_mapping(node: Dict): - """choose random factor""" - return 5 - - w.set_node_scale_factor_mapping(custom_factor_mapping) - - def custom_label_styles_mapping(node: Dict): - """let the label be the negated purple big index""" - return { - 'text': node["id"], - 'backgroundColor': None, - 'fontSize': 20, - 'color': '#030200', - 'position': 'center', - 'maximumWidth': 130, - 'wrapping': 'word', - 'textAlignment': 'center' - } - - w.set_node_label_mapping(custom_label_styles_mapping) - - def custom_edge_color_mapping(edge: Dict): - if edge['id'] == "Unique to Network 1": - return "#e3941e" - elif edge['id'] == "Unique to Network 2": - return "#36f55f" - elif edge['id'] == "Common": - return "#3643f5" - elif edge['id'] == "Conflicting": - return "#ffcc00" - - w.set_edge_color_mapping(custom_edge_color_mapping) + w.set_edge_color_mapping(self.custom_edge_color_mapping) + w.set_node_styles_mapping(self.custom_node_color_mapping) + w.set_node_scale_factor_mapping(self.custom_factor_mapping) + w.set_node_label_mapping(self.custom_label_styles_mapping) w.directed = directed w.graph_layout = graph_layout display(w) + @staticmethod + def custom_edge_color_mapping(edge: Dict): + return edge["properties"]["color"] -# Example usage -# network is assumed to be a custom object with edges, nodes, and initial_nodes attributes. -# Here is a mock example of how to create such an object. You should replace it with your actual network data. + @staticmethod + def custom_node_color_mapping(node: Dict): + return {"color": node["color"]} + @staticmethod + def custom_factor_mapping(node: Dict): + return 5 + + @staticmethod + def custom_label_styles_mapping(node: Dict): + label_style = get_styles()['default']['labels'].copy() + label_style['text'] = node["properties"]["label"] + return label_style + + +# Example usage # class MockNetwork: # def __init__(self): # self.edges = pd.DataFrame({ @@ -183,5 +136,5 @@ def custom_edge_color_mapping(edge: Dict): # self.initial_nodes = ['A', 'B'] # # network = MockNetwork() -# visualizer = NetworkXVisualizer(network) -# visualizer.render(view=True) +# visualizer = YFilesVisualizer(network) +# visualizer.yfiles_visual(graph_layout="hierarchic", directed=True) diff --git a/networkcommons/_visual/yfiles_styles.py b/networkcommons/_visual/yfiles_styles.py new file mode 100644 index 0000000..888e6c1 --- /dev/null +++ b/networkcommons/_visual/yfiles_styles.py @@ -0,0 +1,133 @@ +def get_styles(): + """ + Return a dictionary containing styles for different types of networks specific to yFiles visualizations. + """ + styles = { + 'default': { + 'nodes': { + 'shape': 'round-rectangle', + 'color': '#cccccc', + 'style': 'filled', + 'fillcolor': '#cccccc', + 'label': '', + 'penwidth': 1, + 'fontSize': 12, + 'textAlignment': 'center' + }, + 'edges': { + 'color': 'gray', + 'penwidth': 1 + }, + 'labels': { + 'text': '', + 'backgroundColor': None, + 'fontSize': 12, + 'color': '#030200', + 'shape': 'round-rectangle', + 'textAlignment': 'center' + } + }, + 'highlight': { + 'nodes': { + 'shape': 'round-rectangle', + 'color': '#ffcc00', + 'style': 'filled', + 'fillcolor': '#ffcc00', + 'label': '', + 'penwidth': 2, + 'fontSize': 12, + 'textAlignment': 'center' + }, + 'edges': { + 'color': '#ffcc00', + 'penwidth': 2 + }, + 'labels': { + 'text': '', + 'backgroundColor': None, + 'fontSize': 12, + 'color': '#030200', + 'shape': 'round-rectangle', + 'textAlignment': 'center' + } + }, + 'comparison': { + 'nodes': { + 'Unique to Network 1': '#f5f536', + 'Unique to Network 2': '#36f55f', + 'Common': '#3643f5' + }, + 'edges': { + 'Unique to Network 1': '#e3941e', + 'Unique to Network 2': '#36f55f', + 'Common': '#3643f5', + 'Conflicting': '#ffcc00' + } + }, + 'effects': { + 'stimulation': 'green', + 'inhibition': 'red', + 'form complex': 'blue', + 'bimodal': 'purple', + 'undefined': 'gray', + 'default': 'black' + } + # Additional styles can be added here + } + + return styles + + +def apply_node_style(node, style): + """ + Apply the given style to a node. + + Args: + node (dict): The node to style. + style (dict): The style dictionary with node attributes. + """ + for attr, value in style.items(): + node[attr] = value + return node + + +def apply_edge_style(edge, style): + """ + Apply the given style to an edge. + + Args: + edge (dict): The edge to style. + style (dict): The style dictionary with edge attributes. + """ + for attr, value in style.items(): + edge[attr] = value + return edge + + +def get_edge_color(effect, styles): + """ + Get the color for an edge based on its effect. + + Args: + effect (str): The effect type of the edge. + styles (dict): The styles dictionary. + + Returns: + str: The color for the edge. + """ + return styles['effects'].get(effect, styles['effects']['default']) + + +def get_comparison_color(category, styles, element_type='nodes'): + """ + Get the color for nodes or edges based on the comparison category. + + Args: + category (str): The comparison category. + styles (dict): The styles dictionary. + element_type (str): The type of element ('nodes' or 'edges'). + + Returns: + str: The color for the element. + """ + return styles['comparison'][element_type].get(category, styles['default'][element_type]['color']) From 4ff4ccdec61c9bc6a8a8be2c778d3044a3247429 Mon Sep 17 00:00:00 2001 From: Olga Ivanova Date: Tue, 2 Jul 2024 16:18:48 +0200 Subject: [PATCH 3/4] lollipop plots for networks stat + basic rna-seq stat (networkx based) --- networkcommons/__init__.py | 6 +- networkcommons/_visual/__init__.py | 2 +- networkcommons/_visual/network_stats.py | 128 ++++++++ networkcommons/_visual/rnaseq.py | 295 ++++++++++++++++++ .../_visual/{networkx.py => vis_networkx.py} | 0 networkcommons/_visual/yfiles.py | 1 + tests/test_visual_networkx.py | 2 +- 7 files changed, 429 insertions(+), 5 deletions(-) create mode 100644 networkcommons/_visual/network_stats.py create mode 100644 networkcommons/_visual/rnaseq.py rename networkcommons/_visual/{networkx.py => vis_networkx.py} (100%) diff --git a/networkcommons/__init__.py b/networkcommons/__init__.py index d1141df..7e1d5f5 100644 --- a/networkcommons/__init__.py +++ b/networkcommons/__init__.py @@ -33,7 +33,7 @@ 'visual', ] -import lazy_import +#import lazy_import from ._metadata import __author__, __version__ from ._session import log, _log, session @@ -49,6 +49,6 @@ 'visual', ] -for _mod in _MODULES: +#for _mod in _MODULES: - globals()[_mod] = lazy_import.lazy_module(f'{__name__}.{_mod}') +# globals()[_mod] = lazy_import.lazy_module(f'{__name__}.{_mod}') diff --git a/networkcommons/_visual/__init__.py b/networkcommons/_visual/__init__.py index 36120f4..1676ec2 100644 --- a/networkcommons/_visual/__init__.py +++ b/networkcommons/_visual/__init__.py @@ -17,5 +17,5 @@ Visualization methods for networks and analyses. """ -from . import networkx as networkx +from . import vis_networkx as networkx from . import yfiles as yfiles diff --git a/networkcommons/_visual/network_stats.py b/networkcommons/_visual/network_stats.py new file mode 100644 index 0000000..2aabdc0 --- /dev/null +++ b/networkcommons/_visual/network_stats.py @@ -0,0 +1,128 @@ +import pandas as pd +import matplotlib.pyplot as plt +import networkx as nx +from typing import List, Dict + + +def plot_n_nodes_edges( + networks: Dict[str, nx.DiGraph], + filepath=None, + render=False, + orientation='vertical', + color_palette='Set2', + size=10, + linewidth=2, + marker='o', + show_nodes=True, + show_edges=True +): + """ + Plot the number of nodes and edges in the networks using a lollipop plot. + + Args: + networks (Dict[str, nx.DiGraph]): A dictionary of network names and their corresponding graphs. + filepath (str): Path to save the plot. Default is None. + render (bool): Whether to display the plot. Default is False. + orientation (str): 'vertical' or 'horizontal'. Default is 'vertical'. + color_palette (str): Matplotlib color palette. Default is 'Set2'. + size (int): Size of the markers. Default is 10. + linewidth (int): Line width of the lollipops. Default is 2. + marker (str): Marker style for the lollipops. Default is 'o'. + show_nodes (bool): Whether to show nodes count. Default is True. + show_edges (bool): Whether to show edges count. Default is True. + """ + if not show_nodes and not show_edges: + raise ValueError("At least one of 'show_nodes' or 'show_edges' must be True.") + + # Set the color palette + palette = plt.get_cmap(color_palette) + colors = palette.colors if hasattr(palette, 'colors') else palette(range(len(networks))) + + fig, ax = plt.subplots(figsize=(12, 8)) + + for idx, (network_name, network) in enumerate(networks.items()): + # Get the number of nodes and edges + n_nodes = len(network.nodes) + n_edges = len(network.edges) + categories = [] + values = [] + + if show_nodes: + categories.append('Nodes') + values.append(n_nodes) + if show_edges: + categories.append('Edges') + values.append(n_edges) + + color = colors[idx % len(colors)] + + if orientation == 'vertical': + positions = [f"{network_name} {cat}" for cat in categories] + ax.vlines(x=positions, ymin=0, ymax=values, color=color, linewidth=linewidth, label=network_name) + ax.scatter(positions, values, color=color, s=size ** 2, marker=marker, zorder=3) + + # Annotate the values + for i, value in enumerate(values): + offset = size * 0.1 if value < 10 else size * 0.2 + ax.text(positions[i], value + offset, str(value), ha='center', va='bottom', fontsize=size) + else: + positions = [f"{network_name} {cat}" for cat in categories] + ax.hlines(y=positions, xmin=0, xmax=values, color=color, linewidth=linewidth, label=network_name) + ax.scatter(values, positions, color=color, s=size ** 2, marker=marker, zorder=3) + + # Annotate the values + for i, value in enumerate(values): + offset = size * 0.1 if value < 10 else size * 0.2 + ax.text(value + offset, positions[i], str(value), va='center', ha='left', fontsize=size) + + # Set the axis labels + if orientation == 'vertical': + ax.set_xlabel("Network and Type") + ax.set_ylabel("Count") + else: + ax.set_ylabel("Network and Type") + ax.set_xlabel("Count") + + # Set the title depending on the categories + title = "Number of Nodes and Edges" + if show_nodes and not show_edges: + title = "Number of Nodes" + elif show_edges and not show_nodes: + title = "Number of Edges" + ax.set_title(title) + + # Add a legend + ax.legend() + + # Save the plot + if filepath is not None: + plt.savefig(filepath) + + # Render the plot + if render: + plt.show() + + +# Test the function with sample networks +# if __name__ == "__main__": +# # Create sample directed graphs +# G1 = nx.DiGraph() +# G1.add_nodes_from(range(10)) # Adding 10 nodes +# G1.add_edges_from([(i, i + 1) for i in range(9)]) # Adding 9 edges +# +# G2 = nx.DiGraph() +# G2.add_nodes_from(range(15)) # Adding 15 nodes +# G2.add_edges_from([(i, i + 1) for i in range(14)]) # Adding 14 edges +# +# G3 = nx.DiGraph() +# G3.add_nodes_from(range(5)) # Adding 5 nodes +# G3.add_edges_from([(i, (i + 1) % 5) for i in range(5)]) # Adding 5 edges +# +# G4 = nx.DiGraph() +# G4.add_nodes_from(range(20)) # Adding 20 nodes +# G4.add_edges_from([(i, i + 1) for i in range(19)]) # Adding 19 edges +# +# networks = {'Network1': G1, 'Network2': G2, 'Network3': G3, 'Network4': G4} +# +# plot_n_nodes_edges(networks, filepath="nodes_edges_plot.png", render=True, orientation='horizontal', +# color_palette='Set2', size=12, linewidth=2, marker='o', show_nodes=True, show_edges=True) diff --git a/networkcommons/_visual/rnaseq.py b/networkcommons/_visual/rnaseq.py new file mode 100644 index 0000000..84aa6dd --- /dev/null +++ b/networkcommons/_visual/rnaseq.py @@ -0,0 +1,295 @@ +import pandas as pd +import numpy as np +import matplotlib.pyplot as plt +import seaborn as sns +from sklearn.decomposition import PCA +from typing import List + + +def build_volcano_plot( + data: pd.DataFrame, + log2fc: str, + pval: str, + pval_threshold: float = 0.05, + log2fc_threshold: float = 1, + title: str = "Volcano Plot", + xlabel: str = "log2 Fold Change", + ylabel: str = "-log10(p-value)", + colors: tuple = ("gray", "red"), + alpha: float = 0.7, + size: int = 50, + save: bool = False, + output_dir: str = "." +): + data['-log10(pval)'] = -np.log10(data[pval]) + data['significant'] = (data[pval] < pval_threshold) & (abs(data[log2fc]) >= log2fc_threshold) + + fig, ax = plt.subplots(figsize=(10, 10)) + + ax.scatter( + data.loc[~data['significant'], log2fc], + data.loc[~data['significant'], '-log10(pval)'], + c=colors[0], + alpha=alpha, + s=size, + label='Non-significant' + ) + + ax.scatter( + data.loc[data['significant'], log2fc], + data.loc[data['significant'], '-log10(pval)'], + c=colors[1], + alpha=alpha, + s=size, + label='Significant' + ) + + ax.axhline( + -np.log10(pval_threshold), + color="blue", + linestyle="--" + ) + + ax.axvline( + log2fc_threshold, + color="blue", + linestyle="--" + ) + ax.axvline( + -log2fc_threshold, + color="blue", + linestyle="--" + ) + + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.legend() + + if save: + plt.savefig(f"{output_dir}/volcano_plot.png") + + plt.show() + + +def build_ma_plot( + data: pd.DataFrame, + log2fc: str, + mean_exp: str, + log2fc_threshold: float = 1, + title: str = "MA Plot", + xlabel: str = "Mean Expression", + ylabel: str = "log2 Fold Change", + colors: tuple = ("gray", "red"), + alpha: float = 0.7, + size: int = 50, + save: bool = False, + output_dir: str = "." +): + data['significant'] = abs(data[log2fc]) >= log2fc_threshold + + fig, ax = plt.subplots(figsize=(10, 10)) + + ax.scatter( + data.loc[~data['significant'], mean_exp], + data.loc[~data['significant'], log2fc], + c=colors[0], + alpha=alpha, + s=size, + label='Non-significant' + ) + + ax.scatter( + data.loc[data['significant'], mean_exp], + data.loc[data['significant'], log2fc], + c=colors[1], + alpha=alpha, + s=size, + label='Significant' + ) + + ax.axhline( + 0, + color="blue", + linestyle="--" + ) + + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(title) + ax.legend() + + if save: + plt.savefig(f"{output_dir}/ma_plot.png") + + plt.show() + + +def build_pca_plot( + data: pd.DataFrame, + title: str = "PCA Plot", + xlabel: str = "PC1", + ylabel: str = "PC2", + alpha: float = 0.7, + size: int = 50, + save: bool = False, + output_dir: str = "." +): + pca = PCA(n_components=2) + principal_components = pca.fit_transform(data) + pca_df = pd.DataFrame(data=principal_components, columns=['PC1', 'PC2']) + + fig, ax = plt.subplots(figsize=(10, 10)) + + ax.scatter( + pca_df['PC1'], + pca_df['PC2'], + alpha=alpha, + s=size + ) + + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + ax.set_title(title) + + if save: + plt.savefig(f"{output_dir}/pca_plot.png") + + plt.show() + + +def build_heatmap_with_tree( + data: pd.DataFrame, + top_n: int = 50, + value_column: str = 'log2FoldChange_condition_1', + conditions: List[str] = None, + title: str = "Heatmap of Top Differentially Expressed Genes", + save: bool = False, + output_dir: str = "." +): + """ + Build a heatmap with hierarchical clustering for the top differentially expressed genes across multiple conditions. + + Args: + data (pd.DataFrame): DataFrame containing RNA-seq results. + top_n (int): Number of top differentially expressed genes to include in the heatmap. + value_column (str): Column name for the values to rank and select the top genes. + conditions (List[str]): List of condition columns to include in the heatmap. + title (str): Title of the plot. + save (bool): Whether to save the plot. Default is False. + output_dir (str): Directory to save the plot. Default is ".". + """ + if conditions is None: + raise ValueError("Conditions must be provided as a list of column names.") + + # Select top differentially expressed genes + top_genes = data.nlargest(top_n, value_column).index + top_data = data.loc[top_genes, conditions] + + # Create the clustermap + g = sns.clustermap(top_data, cmap="viridis", cbar=True, fmt=".2f", linewidths=.5) + + plt.title(title) + plt.ylabel("Gene") + plt.xlabel("Condition") + + if save: + plt.savefig(f"{output_dir}/heatmap_with_tree.png") + + plt.show() + + +# Test the functions with the generated dataset +# Example condition columns for plotting +# np.random.seed(42) +# +# # Number of genes and conditions +# num_genes = 1000 +# num_conditions = 4 +# +# # Generate gene names +# genes = [f"gene_{i}" for i in range(num_genes)] +# +# # Generate mean expression values for each gene +# mean_expression = np.random.uniform(1, 100, num_genes) +# +# # Generate log2 fold changes and p-values for each condition +# log2_fold_changes = { +# f'log2FoldChange_condition_{i + 1}': np.random.randn(num_genes) for i in range(num_conditions) +# } +# p_values = { +# f'pvalue_condition_{i + 1}': np.random.uniform(0, 1, num_genes) for i in range(num_conditions) +# } +# +# # Combine data into a DataFrame +# data = pd.DataFrame({ +# 'meanExpression': mean_expression, +# **log2_fold_changes, +# **p_values +# }, index=genes) +# +# # Display the first few rows of the generated data +# print(data.head()) +# +# # Save the data to a CSV file for later use (optional) +# data.to_csv('simulated_rnaseq_data.csv') +# +# conditions = [f'log2FC_condition_{i + 1}' for i in range(num_conditions)] +# +# # Plot the Volcano plot for the first condition +# build_volcano_plot( +# data=data, +# log2fc='log2FoldChange_condition_1', +# pval='pvalue_condition_1', +# pval_threshold=0.05, +# log2fc_threshold=1, +# title="Sample Volcano Plot", +# xlabel="log2 Fold Change", +# ylabel="-log10(p-value)", +# colors=("gray", "red"), +# alpha=0.7, +# size=20, +# save=False, +# output_dir="." +# ) +# +# # Plot the MA plot for the first condition +# build_ma_plot( +# data=data, +# log2fc='log2FoldChange_condition_1', +# mean_exp='meanExpression', +# log2fc_threshold=1, +# title="Sample MA Plot", +# xlabel="Mean Expression", +# ylabel="log2 Fold Change", +# colors=("gray", "red"), +# alpha=0.7, +# size=20, +# save=False, +# output_dir="." +# ) +# +# # Plot the PCA plot for all conditions +# pca_conditions = [f'log2FoldChange_condition_{i + 1}' for i in range(num_conditions)] +# build_pca_plot( +# data=data[pca_conditions], +# title="Sample PCA Plot", +# xlabel="PC1", +# ylabel="PC2", +# alpha=0.7, +# size=20, +# save=False, +# output_dir="." +# ) +# +# +# # Plot the Heatmap with hierarchical clustering for the top differentially expressed genes across all conditions +# build_heatmap_with_tree( +# data=data, +# top_n=50, +# value_column='log2FoldChange_condition_1', +# conditions=pca_conditions, +# title="Sample Heatmap of Top Differentially Expressed Genes", +# save=False, +# output_dir="." +# ) diff --git a/networkcommons/_visual/networkx.py b/networkcommons/_visual/vis_networkx.py similarity index 100% rename from networkcommons/_visual/networkx.py rename to networkcommons/_visual/vis_networkx.py diff --git a/networkcommons/_visual/yfiles.py b/networkcommons/_visual/yfiles.py index 65f5c57..7fa2151 100644 --- a/networkcommons/_visual/yfiles.py +++ b/networkcommons/_visual/yfiles.py @@ -11,6 +11,7 @@ #set_custom_edge_color, get_edge_color, get_comparison_color) + class YFilesVisualizer: def __init__(self, network): diff --git a/tests/test_visual_networkx.py b/tests/test_visual_networkx.py index d00152b..97c731e 100644 --- a/tests/test_visual_networkx.py +++ b/tests/test_visual_networkx.py @@ -2,7 +2,7 @@ import networkx as nx -import networkcommons._visual.networkx as _vis +import networkcommons._visual.vis_networkx as _vis class TestVisualizeNetwork: From 42a420fec4facc3a4215ba0786f8388c0de5190d Mon Sep 17 00:00:00 2001 From: Olga Ivanova Date: Tue, 2 Jul 2024 16:22:30 +0200 Subject: [PATCH 4/4] quick fixes for imports --- networkcommons/__init__.py | 6 +++--- networkcommons/_visual/__init__.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/networkcommons/__init__.py b/networkcommons/__init__.py index 7e1d5f5..d1141df 100644 --- a/networkcommons/__init__.py +++ b/networkcommons/__init__.py @@ -33,7 +33,7 @@ 'visual', ] -#import lazy_import +import lazy_import from ._metadata import __author__, __version__ from ._session import log, _log, session @@ -49,6 +49,6 @@ 'visual', ] -#for _mod in _MODULES: +for _mod in _MODULES: -# globals()[_mod] = lazy_import.lazy_module(f'{__name__}.{_mod}') + globals()[_mod] = lazy_import.lazy_module(f'{__name__}.{_mod}') diff --git a/networkcommons/_visual/__init__.py b/networkcommons/_visual/__init__.py index 1676ec2..fb98e96 100644 --- a/networkcommons/_visual/__init__.py +++ b/networkcommons/_visual/__init__.py @@ -17,5 +17,5 @@ Visualization methods for networks and analyses. """ -from . import vis_networkx as networkx -from . import yfiles as yfiles +from . import vis_networkx as vis_networkx +from . import yfiles as vis_yfiles \ No newline at end of file