Skip to content

Commit

Permalink
fix: Use clone_replace instead of graph_replace
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed May 2, 2024
1 parent 5e6cfc5 commit 34e593d
Showing 1 changed file with 2 additions and 5 deletions.
7 changes: 2 additions & 5 deletions python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,10 +341,7 @@ def _make_functions(model):
if use_split:
variables = pt.split(joined, splits, len(splits))
else:
variables = [
joined[slice_val].reshape(shape)
for slice_val, shape in zip(joined_slices, joined_shapes)
]
variables = [joined[slice_val] for slice_val in zip(joined_slices)]

replacements = {
model.rvs_to_values[var]: value.reshape(shape) if len(shape) != 1 else value
Expand All @@ -355,7 +352,7 @@ def _make_functions(model):
)
}

(logp, grad) = pytensor.graph_replace([logp, grad], replacements)
(logp, grad) = pytensor.clone_replace([logp, grad], replacements)

# We should avoid compiling the function, and optimize only
with model:
Expand Down

0 comments on commit 34e593d

Please sign in to comment.