Skip to content

Commit

Permalink
Allow debug evaling IR logp graphs
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jan 30, 2025
1 parent 268e13b commit 2fed5fd
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
3 changes: 2 additions & 1 deletion pymc/logprob/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,8 @@ def make_node(self, rv, value):
return Apply(self, [rv, value], [rv.type(name=rv.name)])

def perform(self, node, inputs, out):
raise NotImplementedError("ValuedVar should not be present in the final graph!")
warnings.warn("ValuedVar should not be present in the final graph!")
out[0][0] = inputs[0]

def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]
Expand Down
4 changes: 3 additions & 1 deletion pymc/logprob/transform_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from collections.abc import Sequence

Expand Down Expand Up @@ -40,7 +41,8 @@ def make_node(self, tran_value: TensorVariable, value: TensorVariable):
return Apply(self, [tran_value, value], [tran_value.type()])

def perform(self, node, inputs, outputs):
raise NotImplementedError("These `Op`s should be removed from graphs used for computation.")
warnings.warn("TransformedValue should not be present in the final graph!")
outputs[0][0] = inputs[0]

def infer_shape(self, fgraph, node, input_shapes):
return [input_shapes[0]]
Expand Down
23 changes: 23 additions & 0 deletions tests/logprob/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,3 +436,26 @@ def test_ir_rewrite_does_not_disconnect_valued_rvs():
logp_b.eval({a_value: np.pi, b_value: np.e}),
stats.norm.logpdf(np.e, np.pi * 8, 1),
)


def test_ir_ops_can_be_evaluated_with_warning():
_eval_values = [None, None]

def my_logp(value, lam):
nonlocal _eval_values
_eval_values[0] = value.eval()
_eval_values[1] = lam.eval({"lam_log__": -1.5})
return value * lam

with pm.Model() as m:
lam = pm.Exponential("lam")
pm.CustomDist("y", lam, logp=my_logp, observed=[0, 1, 2])

with pytest.warns(
UserWarning, match="TransformedValue should not be present in the final graph"
):
with pytest.warns(UserWarning, match="ValuedVar should not be present in the final graph"):
m.logp()

assert _eval_values[0].sum() == 3
assert _eval_values[1] == np.exp(-1.5)

0 comments on commit 2fed5fd

Please sign in to comment.