diff --git a/conda-envs/environment-dev.yml b/conda-envs/environment-dev.yml index b339441f2d1..1446ca53871 100644 --- a/conda-envs/environment-dev.yml +++ b/conda-envs/environment-dev.yml @@ -17,6 +17,7 @@ dependencies: - pandas>=0.24.0 - pip - python-graphviz +- networkx - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for dev, testing and docs build diff --git a/conda-envs/environment-test.yml b/conda-envs/environment-test.yml index 7d26572b78d..e4fba70d9a6 100644 --- a/conda-envs/environment-test.yml +++ b/conda-envs/environment-test.yml @@ -19,6 +19,7 @@ dependencies: - numpy>=1.15.0 - pandas>=0.24.0 - python-graphviz +- networkx - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for testing diff --git a/conda-envs/windows-environment-dev.yml b/conda-envs/windows-environment-dev.yml index 92bd42d1ee5..013b107d33b 100644 --- a/conda-envs/windows-environment-dev.yml +++ b/conda-envs/windows-environment-dev.yml @@ -17,6 +17,7 @@ dependencies: - pandas>=0.24.0 - pip - python-graphviz +- networkx - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for dev, testing and docs build diff --git a/conda-envs/windows-environment-test.yml b/conda-envs/windows-environment-test.yml index 7df8bb0a97d..1f3432ffd07 100644 --- a/conda-envs/windows-environment-test.yml +++ b/conda-envs/windows-environment-test.yml @@ -20,6 +20,7 @@ dependencies: - pandas>=0.24.0 - pip - python-graphviz +- networkx - scipy>=1.4.1 - typing-extensions>=3.7.4 # Extra dependencies for testing diff --git a/docs/source/api/model.rst b/docs/source/api/model.rst index 81bf12711ab..0c288978aae 100644 --- a/docs/source/api/model.rst +++ b/docs/source/api/model.rst @@ -10,6 +10,7 @@ Model creation and inspection Model model_to_graphviz + model_to_networkx modelcontext Others diff --git a/pymc/__init__.py b/pymc/__init__.py index 405b05d0fc6..2ad6d70ae6f 100644 --- a/pymc/__init__.py +++ b/pymc/__init__.py @@ -65,7 +65,7 @@ def __set_compiler_flags(): probit, ) from pymc.model import * -from pymc.model_graph import model_to_graphviz +from pymc.model_graph import model_to_graphviz, model_to_networkx from pymc.plots import * from pymc.printing import * from pymc.sampling import * diff --git a/pymc/model_graph.py b/pymc/model_graph.py index 12bbf57e611..8e28aaa9716 100644 --- a/pymc/model_graph.py +++ b/pymc/model_graph.py @@ -125,8 +125,8 @@ def make_compute_graph( return input_map - def _make_node(self, var_name, graph, *, formatting: str = "plain"): - """Attaches the given variable to a graphviz Digraph""" + def _make_node(self, var_name, graph, *, nx=False, cluster=False, formatting: str = "plain"): + """Attaches the given variable to a graphviz or networkx Digraph""" v = self.model[var_name] shape = None @@ -168,7 +168,13 @@ def _make_node(self, var_name, graph, *, formatting: str = "plain"): "label": label, } - graph.node(var_name.replace(":", "&"), **kwargs) + if cluster: + kwargs["cluster"] = cluster + + if nx: + graph.add_node(var_name.replace(":", "&"), **kwargs) + else: + graph.node(var_name.replace(":", "&"), **kwargs) def _eval(self, var): return function([], var, mode="FAST_COMPILE")() @@ -178,7 +184,6 @@ def get_plates(self, var_names: Optional[Iterable[VarName]] = None) -> Dict[str, Just groups by the shape of the underlying distribution. Will be wrong if there are two plates with the same shape. - Returns ------- dict @@ -234,9 +239,134 @@ def make_graph(self, var_names: Optional[Iterable[VarName]] = None, formatting: return graph + def make_networkx( + self, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain" + ): + """Make networkx Digraph of PyMC model + + Returns + ------- + networkx.Digraph + """ + try: + import networkx + except ImportError: + raise ImportError( + "This function requires the python library networkx, along with binaries. " + "The easiest way to install all of this is by running\n\n" + "\tconda install networkx" + ) + graphnetwork = networkx.DiGraph(name=self.model.name) + for plate_label, all_var_names in self.get_plates(var_names).items(): + if plate_label: + # # must be preceded by 'cluster' to get a box around it + + subgraphnetwork = networkx.DiGraph(name="cluster" + plate_label, label=plate_label) + + for var_name in all_var_names: + self._make_node( + var_name, + subgraphnetwork, + nx=True, + cluster="cluster" + plate_label, + formatting=formatting, + ) + for sgn in subgraphnetwork.nodes: + networkx.set_node_attributes( + subgraphnetwork, + {sgn: {"labeljust": "r", "labelloc": "b", "style": "rounded"}}, + ) + node_data = { + e[0]: e[1] + for e in graphnetwork.nodes(data=True) & subgraphnetwork.nodes(data=True) + } + + graphnetwork = networkx.compose(graphnetwork, subgraphnetwork) + networkx.set_node_attributes(graphnetwork, node_data) + graphnetwork.graph["name"] = self.model.name + else: + for var_name in all_var_names: + + self._make_node(var_name, graphnetwork, nx=True, formatting=formatting) + + for child, parents in self.make_compute_graph(var_names=var_names).items(): + # parents is a set of rv names that preceed child rv nodes + for parent in parents: + graphnetwork.add_edge(parent.replace(":", "&"), child.replace(":", "&")) + return graphnetwork + + +def model_to_networkx( + model=None, + *, + var_names: Optional[Iterable[VarName]] = None, + formatting: str = "plain", +): + """Produce a networkx Digraph from a PyMC model. + + Requires networkx, which may be installed most easily with:: + + conda install networkx + + Alternatively, you may install using pip with:: + + pip install networkx + + See https://networkx.org/documentation/stable/ for more information. + + Parameters + ---------- + model : Model + The model to plot. Not required when called from inside a modelcontext. + var_names : iterable of str, optional + Subset of variables to be plotted that identify a subgraph with respect to the entire model graph + formatting : str, optional + one of { "plain" } + + Examples + -------- + How to plot the graph of the model. + + .. code-block:: python + + import numpy as np + from pymc import HalfCauchy, Model, Normal, model_to_networkx + + J = 8 + y = np.array([28, 8, -3, 7, -1, 1, 18, 12]) + sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18]) + + with Model() as schools: + + eta = Normal("eta", 0, 1, shape=J) + mu = Normal("mu", 0, sigma=1e6) + tau = HalfCauchy("tau", 25) + + theta = mu + tau * eta + + obs = Normal("obs", theta, sigma=sigma, observed=y) + + model_to_networkx(schools) + """ + if not "plain" in formatting: + + raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.") + + if formatting != "plain": + warnings.warn( + "Formattings other than 'plain' are currently not supported.", + UserWarning, + stacklevel=2, + ) + model = pm.modelcontext(model) + return ModelGraph(model).make_networkx(var_names=var_names, formatting=formatting) + def model_to_graphviz( - model=None, *, var_names: Optional[Iterable[VarName]] = None, formatting: str = "plain" + model=None, + *, + var_names: Optional[Iterable[VarName]] = None, + formatting: str = "plain", ): """Produce a graphviz Digraph from a PyMC model. @@ -286,7 +416,9 @@ def model_to_graphviz( raise ValueError(f"Unsupported formatting for graph nodes: '{formatting}'. See docstring.") if formatting != "plain": warnings.warn( - "Formattings other than 'plain' are currently not supported.", UserWarning, stacklevel=2 + "Formattings other than 'plain' are currently not supported.", + UserWarning, + stacklevel=2, ) model = pm.modelcontext(model) return ModelGraph(model).make_graph(var_names=var_names, formatting=formatting) diff --git a/pymc/tests/test_model_graph.py b/pymc/tests/test_model_graph.py index f13f25647dd..c2c504ffaee 100644 --- a/pymc/tests/test_model_graph.py +++ b/pymc/tests/test_model_graph.py @@ -20,10 +20,64 @@ import pymc as pm -from pymc.model_graph import ModelGraph, model_to_graphviz +from pymc.model_graph import ModelGraph, model_to_graphviz, model_to_networkx from pymc.tests.helpers import SeededTest +def school_model(): + """ + Schools model to use in testing model_to_networkx function + """ + J = 8 + y = np.array([28, 8, -3, 7, -1, 1, 18, 12]) + sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18]) + with pm.Model() as schools: + eta = pm.Normal("eta", 0, 1, shape=J) + mu = pm.Normal("mu", 0, sigma=1e6) + tau = pm.HalfCauchy("tau", 25) + theta = mu + tau * eta + obs = pm.Normal("obs", theta, sigma=sigma, observed=y) + return schools + + +class BaseModelNXTest(SeededTest): + network_model = { + "graph_attr_dict_factory": dict, + "node_dict_factory": dict, + "node_attr_dict_factory": dict, + "adjlist_outer_dict_factory": dict, + "adjlist_inner_dict_factory": dict, + "edge_attr_dict_factory": dict, + "graph": {"name": "", "label": "8"}, + "_node": { + "eta": { + "shape": "ellipse", + "style": "rounded", + "label": "eta\n~\nNormal", + "cluster": "cluster8", + "labeljust": "r", + "labelloc": "b", + }, + "obs": { + "shape": "ellipse", + "style": "rounded", + "label": "obs\n~\nNormal", + "cluster": "cluster8", + "labeljust": "r", + "labelloc": "b", + }, + "tau": {"shape": "ellipse", "style": None, "label": "tau\n~\nHalfCauchy"}, + "mu": {"shape": "ellipse", "style": None, "label": "mu\n~\nNormal"}, + }, + "_adj": {"eta": {"obs": {}}, "obs": {}, "tau": {"obs": {}}, "mu": {"obs": {}}}, + "_pred": {"eta": {}, "obs": {"tau": {}, "eta": {}, "mu": {}}, "tau": {}, "mu": {}}, + "_succ": {"eta": {"obs": {}}, "obs": {}, "tau": {"obs": {}}, "mu": {"obs": {}}}, + } + + def test_networkx(self): + assert self.network_model == model_to_networkx(school_model()).__dict__ + + def radon_model(): """Similar in shape to the Radon model""" n_homes = 919 diff --git a/scripts/generate_pip_deps_from_conda.py b/scripts/generate_pip_deps_from_conda.py index f3d15fab45c..cbdc7791fa4 100755 --- a/scripts/generate_pip_deps_from_conda.py +++ b/scripts/generate_pip_deps_from_conda.py @@ -51,6 +51,7 @@ "mkl-service", "numba", "python-graphviz", + "networkx", "blas", "jax", }