From 43b84f56dcea9d4edf7b532bf4a428f12d10e5b3 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Thu, 27 Feb 2025 12:13:22 +0100
Subject: [PATCH] Supervisor not needed for JAX rewrites

As the JAX rewrites no longer include inplace operations.
---
 pymc/sampling/jax.py | 14 +-------------
 1 file changed, 1 insertion(+), 13 deletions(-)

diff --git a/pymc/sampling/jax.py b/pymc/sampling/jax.py
index 193554380c..6c01192825 100644
--- a/pymc/sampling/jax.py
+++ b/pymc/sampling/jax.py
@@ -30,7 +30,7 @@
 from arviz.data.base import make_attrs
 from jax.lax import scan
 from numpy.typing import ArrayLike
-from pytensor.compile import SharedVariable, Supervisor, mode
+from pytensor.compile import SharedVariable, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import clone_replace
@@ -127,18 +127,6 @@ def get_jaxified_graph(
     graph = _replace_shared_variables(outputs) if outputs is not None else None
 
     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
-    # We need to add a Supervisor to the fgraph to be able to run the
-    # JAX sequential optimizer without warnings. We made sure there
-    # are no mutable input variables, so we only need to check for
-    # "destroyers". This should be automatically handled by PyTensor
-    # once https://github.com/aesara-devs/aesara/issues/637 is fixed.
-    fgraph.attach_feature(
-        Supervisor(
-            input
-            for input in fgraph.inputs
-            if not (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
-        )
-    )
     mode.JAX.optimizer.rewrite(fgraph)
 
     # We now jaxify the optimized fgraph
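
A minimal sketch (not part of the patch) of the premise behind this change:
after mode.JAX.optimizer.rewrite(fgraph) runs, no op in the rewritten graph
declares a destroy_map, i.e. nothing mutates its inputs, so there is nothing
for a Supervisor to guard. The toy graph below is a hypothetical stand-in for
a jaxified PyMC model graph.

import pytensor.tensor as pt
from pytensor.compile import mode
from pytensor.graph.fg import FunctionGraph

# Hypothetical toy graph; any graph built from PyTensor ops would do.
x = pt.vector("x")
out = pt.exp(x).sum()

fgraph = FunctionGraph(inputs=[x], outputs=[out], clone=True)

# Apply the same JAX rewrites used by get_jaxified_graph, without
# attaching a Supervisor feature first.
mode.JAX.optimizer.rewrite(fgraph)

# Premise of the patch: no op in the rewritten graph destroys (works
# inplace on) any of its inputs, so the Supervisor was redundant.
assert not any(
    getattr(node.op, "destroy_map", None) for node in fgraph.apply_nodes
)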