Skip to content

Commit

Permalink
renamed conversors corneto<->networkx
Browse files Browse the repository at this point in the history
  • Loading branch information
vicpaton committed Jul 3, 2024
1 parent cdfce23 commit 9b31209
Show file tree
Hide file tree
Showing 2 changed files with 4,228 additions and 196 deletions.
4,383 changes: 4,195 additions & 188 deletions docs/src/vignettes/1_simple_example.ipynb

Large diffs are not rendered by default.

41 changes: 33 additions & 8 deletions networkcommons/methods/_causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from __future__ import annotations

__all__ = [
'convert_cornetograph',
'to_cornetograph',
'to_networkx',
'run_corneto_carnival',
]

Expand All @@ -30,7 +31,7 @@
cn_nx = lazy_import.lazy_module('corneto.contrib.networkx')


def convert_cornetograph(graph):
def to_cornetograph(graph):
"""
Convert a networkx graph to a corneto graph, if needed.
Expand All @@ -43,11 +44,38 @@ def convert_cornetograph(graph):
if isinstance(graph, cn._graph.Graph):
corneto_graph = graph
elif isinstance(graph, (nx.Graph, nx.DiGraph)):
corneto_graph = cn_nx.networkx_to_corneto_graph(graph)
# substitute 'sign' for 'interaction' in the graph
for u, v, data in graph.edges(data=True):
data['interaction'] = data.pop('sign')

corneto_graph = networkx_to_corneto_graph(graph)

return corneto_graph


def to_networkx(graph, skip_unsupported_edges=True):
"""
Convert a corneto graph to a networkx graph, if needed.
Args:
graph (cn.Graph): The corneto graph.
Returns:
nx.Graph: The networkx graph.
"""
if isinstance(graph, nx.Graph) or isinstance(graph, nx.DiGraph):
networkx_graph = graph
elif isinstance(graph, cn._graph.Graph):
networkx_graph = corneto_graph_to_networkx(
graph,
skip_unsupported_edges=skip_unsupported_edges)
# rename interaction for sign
for u, v, data in networkx_graph.edges(data=True):
data['sign'] = data.pop('interaction')

return networkx_graph


def run_corneto_carnival(network,
source_dict,
target_dict,
Expand All @@ -68,7 +96,7 @@ def run_corneto_carnival(network,
nx.Graph: The subnetwork containing the paths found by CARNIVAL.
list: A list containing the paths found by CARNIVAL.
"""
corneto_net = convert_cornetograph(network)
corneto_net = to_cornetograph(network)

problem, graph = cn.methods.runVanillaCarnival(
perturbations=source_dict,
Expand All @@ -83,10 +111,7 @@ def run_corneto_carnival(network,
cn.methods.carnival.get_selected_edges(problem, graph),
)

network_nx = cn_nx.corneto_graph_to_networkx(
network_sol,
skip_unsupported_edges=True,
)
network_nx = to_networkx(network_sol, skip_unsupported_edges=True)

network_nx.remove_nodes_from(['_s', '_pert_c0', '_meas_c0'])

Expand Down

0 comments on commit 9b31209

Please sign in to comment.