Skip to content

Commit

Permalink
Added networkx export functionality (#6046)
Browse files Browse the repository at this point in the history
* Added networkx export functionality

Apologies for delay but finally got around to updating the latest version of model_graph.py to include networkx export function
as discussed in #5677

* Update model_graph.py

restoring whitespace #5677 (comment)

* Corrected cluster(subgraph) behaviour

subgraph(cluster) attributes are now stored on each node that is a member of that subgraph(cluster).

* Update pymc/model_graph.py

Co-authored-by: Oriol Abril-Pla <[email protected]>

* Correcting docstring spacing

* Updated to include new function model_to_networkx

* linted and checked by pre-commit

* added test for model_to_networkx function

* added function import

* added model_to_networkx

* corrected formatting

* more linting

* Update test_model_graph.py

* Update __init__.py

* Redid changes

* added function to models.rst

* corrected indent issue

indent issue cause by merge resolution

* corrected spelling error

* added networkx to requirements

* redid pre-commit and ran it twice

* Update pymc/model_graph.py

Co-authored-by: Oriol Abril-Pla <[email protected]>

* Update pymc/model_graph.py

Co-authored-by: Oriol Abril-Pla <[email protected]>

* Update pymc/model_graph.py

Co-authored-by: Oriol Abril-Pla <[email protected]>

* Update pymc/model_graph.py

Co-authored-by: Oriol Abril-Pla <[email protected]>

* Update pymc/model_graph.py

Co-authored-by: Oriol Abril-Pla <[email protected]>

* Removed trailing whitespace

Removed trailing whitespace from lines 312, 314,

* removing more trailing whitespace

Co-authored-by: Oriol Abril-Pla <[email protected]>
Co-authored-by: Joni Pelham <[email protected]>
  • Loading branch information
3 people authored Aug 15, 2022
1 parent 906fcdc commit 9883915
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 8 deletions.
1 change: 1 addition & 0 deletions conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- pandas>=0.24.0
- pip
- python-graphviz
- networkx
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for dev, testing and docs build
Expand Down
1 change: 1 addition & 0 deletions conda-envs/environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies:
- numpy>=1.15.0
- pandas>=0.24.0
- python-graphviz
- networkx
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for testing
Expand Down
1 change: 1 addition & 0 deletions conda-envs/windows-environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies:
- pandas>=0.24.0
- pip
- python-graphviz
- networkx
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for dev, testing and docs build
Expand Down
1 change: 1 addition & 0 deletions conda-envs/windows-environment-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ dependencies:
- pandas>=0.24.0
- pip
- python-graphviz
- networkx
- scipy>=1.4.1
- typing-extensions>=3.7.4
# Extra dependencies for testing
Expand Down
1 change: 1 addition & 0 deletions docs/source/api/model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Model creation and inspection

Model
model_to_graphviz
model_to_networkx
modelcontext

Others
Expand Down
2 changes: 1 addition & 1 deletion pymc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __set_compiler_flags():
probit,
)
from pymc.model import *
from pymc.model_graph import model_to_graphviz
from pymc.model_graph import model_to_graphviz, model_to_networkx
from pymc.plots import *
from pymc.printing import *
from pymc.sampling import *
Expand Down
144 changes: 138 additions & 6 deletions pymc/model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ def make_compute_graph(

return input_map

def _make_node(self, var_name, graph, *, formatting: str = "plain"):
"""Attaches the given variable to a graphviz Digraph"""
def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: str = "plain"):
"""Attaches the given variable to a graphviz or networkx Digraph"""
v = self.model[var_name]

shape = None
Expand Down Expand Up @@ -168,7 +168,13 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"):
"label": label,
}

graph.node(var_name.replace(":", "&"), **kwargs)
if cluster:
kwargs["cluster"] = cluster

if nx:
graph.add_node(var_name.replace(":", "&"), **kwargs)
else:
graph.node(var_name.replace(":", "&"), **kwargs)

def _eval(self, var):
return function([], var, mode="FAST_COMPILE")()
Expand All @@ -178,7 +184,6 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str,
Just groups by the shape of the underlying distribution. Will be wrong
if there are two plates with the same shape.
Returns
-------
dict
Expand Down Expand Up @@ -234,9 +239,134 @@ def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting:

return graph

def make_networkx(
self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"
):
"""Make networkx Digraph of PyMC model
Returns
-------
networkx.Digraph
"""
try:
import networkx
except ImportError:
raise ImportError(
"This function requires the python library networkx, along with binaries. "
"The easiest way to install all of this is by running\n\n"
"\tconda install networkx"
)
graphnetwork = networkx.DiGraph(name=self.model.name)
for plate_label, all_var_names in self.get_plates(var_names).items():
if plate_label:
# # must be preceded by 'cluster' to get a box around it

subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label)

for var_name in all_var_names:
self._make_node(
var_name,
subgraphnetwork,
nx=True,
cluster="cluster" + plate_label,
formatting=formatting,
)
for sgn in subgraphnetwork.nodes:
networkx.set_node_attributes(
subgraphnetwork,
{sgn: {"labeljust": "r", "labelloc": "b", "style": "rounded"}},
)
node_data = {
e[0]: e[1]
for e in graphnetwork.nodes(data=True) & subgraphnetwork.nodes(data=True)
}

graphnetwork = networkx.compose(graphnetwork, subgraphnetwork)
networkx.set_node_attributes(graphnetwork, node_data)
graphnetwork.graph["name"] = self.model.name
else:
for var_name in all_var_names:

self._make_node(var_name, graphnetwork, nx=True, formatting=formatting)

for child, parents in self.make_compute_graph(var_names=var_names).items():
# parents is a set of rv names that preceed child rv nodes
for parent in parents:
graphnetwork.add_edge(parent.replace(":", "&"), child.replace(":", "&"))
return graphnetwork


def model_to_networkx(
model=None,
*,
var_names: Optional[Iterable[VarName]] = None,
formatting: str = "plain",
):
"""Produce a networkx Digraph from a PyMC model.
Requires networkx, which may be installed most easily with::
conda install networkx
Alternatively, you may install using pip with::
pip install networkx
See https://networkx.org/documentation/stable/ for more information.
Parameters
----------
model : Model
The model to plot. Not required when called from inside a modelcontext.
var_names : iterable of str, optional
Subset of variables to be plotted that identify a subgraph with respect to the entire model graph
formatting : str, optional
one of { "plain" }
Examples
--------
How to plot the graph of the model.
.. code-block:: python
import numpy as np
from pymc import HalfCauchy, Model, Normal, model_to_networkx
J = 8
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
with Model() as schools:
eta = Normal("eta", 0, 1, shape=J)
mu = Normal("mu", 0, sigma=1e6)
tau = HalfCauchy("tau", 25)
theta = mu + tau * eta
obs = Normal("obs", theta, sigma=sigma, observed=y)
model_to_networkx(schools)
"""
if not "plain" in formatting:

raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")

if formatting != "plain":
warnings.warn(
"Formattings other than 'plain' are currently not supported.",
UserWarning,
stacklevel=2,
)
model = pm.modelcontext(model)
return ModelGraph(model).make_networkx(var_names=var_names, formatting=formatting)


def model_to_graphviz(
model=None, *, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain"
model=None,
*,
var_names: Optional[Iterable[VarName]] = None,
formatting: str = "plain",
):
"""Produce a graphviz Digraph from a PyMC model.
Expand Down Expand Up @@ -286,7 +416,9 @@ def model_to_graphviz(
raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.")
if formatting != "plain":
warnings.warn(
"Formattings other than 'plain' are currently not supported.", UserWarning, stacklevel=2
"Formattings other than 'plain' are currently not supported.",
UserWarning,
stacklevel=2,
)
model = pm.modelcontext(model)
return ModelGraph(model).make_graph(var_names=var_names, formatting=formatting)
56 changes: 55 additions & 1 deletion pymc/tests/test_model_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,64 @@

import pymc as pm

from pymc.model_graph import ModelGraph, model_to_graphviz
from pymc.model_graph import ModelGraph, model_to_graphviz, model_to_networkx
from pymc.tests.helpers import SeededTest


def school_model():
"""
Schools model to use in testing model_to_networkx function
"""
J = 8
y = np.array([28, 8, -3, 7, -1, 1, 18, 12])
sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])
with pm.Model() as schools:
eta = pm.Normal("eta", 0, 1, shape=J)
mu = pm.Normal("mu", 0, sigma=1e6)
tau = pm.HalfCauchy("tau", 25)
theta = mu + tau * eta
obs = pm.Normal("obs", theta, sigma=sigma, observed=y)
return schools


class BaseModelNXTest(SeededTest):
network_model = {
"graph_attr_dict_factory": dict,
"node_dict_factory": dict,
"node_attr_dict_factory": dict,
"adjlist_outer_dict_factory": dict,
"adjlist_inner_dict_factory": dict,
"edge_attr_dict_factory": dict,
"graph": {"name": "", "label": "8"},
"_node": {
"eta": {
"shape": "ellipse",
"style": "rounded",
"label": "eta\n~\nNormal",
"cluster": "cluster8",
"labeljust": "r",
"labelloc": "b",
},
"obs": {
"shape": "ellipse",
"style": "rounded",
"label": "obs\n~\nNormal",
"cluster": "cluster8",
"labeljust": "r",
"labelloc": "b",
},
"tau": {"shape": "ellipse", "style": None, "label": "tau\n~\nHalfCauchy"},
"mu": {"shape": "ellipse", "style": None, "label": "mu\n~\nNormal"},
},
"_adj": {"eta": {"obs": {}}, "obs": {}, "tau": {"obs": {}}, "mu": {"obs": {}}},
"_pred": {"eta": {}, "obs": {"tau": {}, "eta": {}, "mu": {}}, "tau": {}, "mu": {}},
"_succ": {"eta": {"obs": {}}, "obs": {}, "tau": {"obs": {}}, "mu": {"obs": {}}},
}

def test_networkx(self):
assert self.network_model == model_to_networkx(school_model()).__dict__


def radon_model():
"""Similar in shape to the Radon model"""
n_homes = 919
Expand Down
1 change: 1 addition & 0 deletions scripts/generate_pip_deps_from_conda.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"mkl-service",
"numba",
"python-graphviz",
"networkx",
"blas",
"jax",
}
Expand Down

0 comments on commit 9883915

Please sign in to comment.