Supervisor not needed for JAX rewrites
As the JAX rewrite pipeline no longer includes inplace operations.
ricardoV94 committed Feb 27, 2025
1 parent 450e7f6 · commit 43b84f5
Showing 1 changed file with 1 addition and 13 deletions.
pymc/sampling/jax.py (1 addition & 13 deletions)
```diff
@@ -30,7 +30,7 @@
 from arviz.data.base import make_attrs
 from jax.lax import scan
 from numpy.typing import ArrayLike
-from pytensor.compile import SharedVariable, Supervisor, mode
+from pytensor.compile import SharedVariable, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import clone_replace
@@ -127,18 +127,6 @@ def get_jaxified_graph(
     graph = _replace_shared_variables(outputs) if outputs is not None else None
 
     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
-    # We need to add a Supervisor to the fgraph to be able to run the
-    # JAX sequential optimizer without warnings. We made sure there
-    # are no mutable input variables, so we only need to check for
-    # "destroyers". This should be automatically handled by PyTensor
-    # once https://github.com/aesara-devs/aesara/issues/637 is fixed.
-    fgraph.attach_feature(
-        Supervisor(
-            input
-            for input in fgraph.inputs
-            if not (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
-        )
-    )
     mode.JAX.optimizer.rewrite(fgraph)
 
     # We now jaxify the optimized fgraph
```
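For orientation, here is a minimal sketch of how the simplified `get_jaxified_graph` can be exercised after this change. The toy graph below is illustrative and not part of the commit; only the `inputs`/`outputs` keyword arguments are taken from the function signature shown in the diff:

```python
import numpy as np
import pytensor.tensor as pt

from pymc.sampling.jax import get_jaxified_graph

# A toy PyTensor graph (hypothetical example, not from the commit).
x = pt.vector("x")
y = pt.exp(x).sum()

# get_jaxified_graph builds a FunctionGraph, runs the JAX rewrites
# directly (no Supervisor feature is attached anymore, since the JAX
# rewrites introduce no inplace operations and hence no destroyers),
# and returns a JAX-callable function.
jax_fn = get_jaxified_graph(inputs=[x], outputs=[y])

# The returned function takes input values positionally and returns
# a sequence with one entry per requested output.
(result,) = jax_fn(np.ones(3))
```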
