Skip to content

Commit

Permalink
use existing function
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Jan 25, 2025
1 parent b36e573 commit 7f71397
Showing 1 changed file with 3 additions and 26 deletions.
29 changes: 3 additions & 26 deletions pymc/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,13 @@
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Variable
from pytensor.raise_op import Assert
from pytensor.scalar import Cast
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.basic import IntegersRV
from pytensor.tensor.random.var import RandomGeneratorSharedVariable
from pytensor.tensor.type import TensorType
from pytensor.tensor.variable import TensorConstant, TensorVariable

import pymc as pm

from pymc.logprob.utils import rvs_in_graph
from pymc.pytensorf import GeneratorOp, convert_data, smarttypeX
from pymc.vartypes import isgenerator

Expand Down Expand Up @@ -149,22 +147,8 @@ def __str__(self):
return "Minibatch"


def first_inputs(r):
if not r.owner:
return

inputs = r.owner.inputs

if not inputs:
return

first_input = inputs[0]
yield first_input
yield from first_inputs(first_input)


def has_random_ancestor(r):
return any(isinstance(i, RandomGeneratorSharedVariable) for i in first_inputs(r))
return rvs_in_graph(r) != set()


def is_valid_observed(v) -> bool:
Expand All @@ -177,14 +161,7 @@ def is_valid_observed(v) -> bool:
return True

return (
# The only PyTensor operation we allow on observed data is type casting
# Although we could allow for any graph that does not depend on other RVs
(
isinstance(v.owner.op, Elemwise)
and isinstance(v.owner.op.scalar_op, Cast)
and is_valid_observed(v.owner.inputs[0])
)
or not has_random_ancestor(v)
not has_random_ancestor(v)
# Or Minibatch
or (
isinstance(v.owner.op, MinibatchOp)
Expand Down

0 comments on commit 7f71397

Please sign in to comment.