diff --git a/pymc/model/core.py b/pymc/model/core.py
index 2f73c9ee243..ff378efa25b 100644
--- a/pymc/model/core.py
+++ b/pymc/model/core.py
@@ -1050,7 +1050,14 @@ def set_dim(self, name: str, new_length: int, coord_values: Sequence | None = No
                     expected=new_length,
                 )
             self._coords[name] = tuple(coord_values)
-        self.dim_lengths[name].set_value(new_length)
+        dim_length = self.dim_lengths[name]
+        if not isinstance(dim_length, SharedVariable):
+            raise TypeError(
+                f"The dim_length of `{name}` must be a `SharedVariable` "
+                "(created through `coords` to allow updating). "
+                f"The current type is: {type(dim_length)}"
+            )
+        dim_length.set_value(new_length)
         return

     def initial_point(self, random_seed: SeedSequenceSeed = None) -> dict[str, np.ndarray]:
@@ -1102,8 +1109,8 @@ def set_data(
         shared_object = self[name]
         if not isinstance(shared_object, SharedVariable):
             raise TypeError(
-                f"The variable `{name}` must be a `SharedVariable`"
-                " (created through `pm.Data()` or `pm.Data(mutable=True)`) to allow updating. "
+                f"The variable `{name}` must be a `SharedVariable` "
+                "(created through `pm.Data()` to allow updating). "
                 f"The current type is: {type(shared_object)}"
             )

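With this patch, `Model.set_dim` validates that the stored dim length is still a `SharedVariable` before calling `set_value`, mirroring the existing check in `Model.set_data`. A minimal sketch of the resulting behavior, assuming a PyMC build that includes this change (the model and variable names are illustrative):

```python
import pymc as pm
from pymc.model.transform.optimization import freeze_dims_and_data

with pm.Model(coords={"obs_idx": range(3)}) as m:
    x = pm.Data("x", [1.0, 2.0, 3.0], dims="obs_idx")
    y = pm.Normal("y", mu=x, dims="obs_idx")

frozen = freeze_dims_and_data(m)
try:
    # The frozen dim length is a TensorConstant, so this now fails with an
    # informative TypeError instead of an opaque attribute error.
    frozen.set_dim("obs_idx", new_length=4, coord_values=range(4))
except TypeError as err:
    print(err)  # The dim_length of `obs_idx` must be a `SharedVariable` ...
```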
""" fg, memo = fgraph_from_model(model) + if dims is None: + dims = tuple(model.dim_lengths.keys()) + if data is None: + data = tuple(model.named_vars.keys()) + # Replace mutable dim lengths and data by constants - frozen_vars = { - memo[dim_length]: constant( - dim_length.get_value(), name=dim_length.name, dtype=dim_length.type.dtype - ) - for dim_length in model.dim_lengths.values() + frozen_replacements = { + memo[dim_length]: _constant_from_shared(dim_length) + for dim_length in (model.dim_lengths[dim_name] for dim_name in dims) if isinstance(dim_length, SharedVariable) } - frozen_vars |= { - memo[data_var].owner.inputs[0]: constant( - data_var.get_value(), name=data_var.name, dtype=data_var.type.dtype - ) - for data_var in model.named_vars.values() - if isinstance(data_var, SharedVariable) + frozen_replacements |= { + memo[datum].owner.inputs[0]: _constant_from_shared(datum) + for datum in (model.named_vars[datum_name] for datum_name in data) + if isinstance(datum, SharedVariable) } - old_outs, coords = fg.outputs, fg._coords # type: ignore + old_outs, old_coords, old_dim_lenghts = fg.outputs, fg._coords, fg._dim_lengths # type: ignore # Rebuild strict will force the recreation of RV nodes with updated static types - new_outs = clone_replace(old_outs, replace=frozen_vars, rebuild_strict=False) # type: ignore + new_outs = clone_replace(old_outs, replace=frozen_replacements, rebuild_strict=False) # type: ignore for old_out, new_out in zip(old_outs, new_outs): new_out.name = old_out.name fg = FunctionGraph(outputs=new_outs, clone=False) - fg._coords = coords # type: ignore + fg._coords = old_coords # type: ignore + fg._dim_lengths = { # type: ignore + dim: frozen_replacements.get(dim_length, dim_length) + for dim, dim_length in old_dim_lenghts.items() + } # Recreate value variables from new RVs to propagate static types to logp graphs replacements = {} for node in fg.apply_nodes: if not isinstance(node.op, ModelFreeRV): continue - rv, old_value, *dims = node.inputs - if dims is None: - continue + rv, old_value, *_ = node.inputs transform = node.op.transform if transform is None: new_value = rv.type() diff --git a/tests/model/transform/test_optimization.py b/tests/model/transform/test_optimization.py index 01e1d394f0b..8198f213114 100644 --- a/tests/model/transform/test_optimization.py +++ b/tests/model/transform/test_optimization.py @@ -11,17 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
diff --git a/tests/model/transform/test_optimization.py b/tests/model/transform/test_optimization.py
index 01e1d394f0b..8198f213114 100644
--- a/tests/model/transform/test_optimization.py
+++ b/tests/model/transform/test_optimization.py
@@ -11,17 +11,22 @@
 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
+import numpy as np
+import pytest
+
+from pytensor.compile import SharedVariable
 from pytensor.graph import Constant

+from pymc import Deterministic
 from pymc.data import Data
 from pymc.distributions import HalfNormal, Normal
 from pymc.model import Model
 from pymc.model.transform.optimization import freeze_dims_and_data


-def test_freeze_existing_rv_dims_and_data():
+def test_freeze_dims_and_data():
     with Model(coords={"test_dim": range(5)}) as m:
-        std = Data("std", [1])
+        std = Data("test_data", [1])
         x = HalfNormal("x", std, dims=("test_dim",))
         y = Normal("y", shape=x.shape[0] + 1)

@@ -34,18 +39,96 @@ def test_freeze_existing_rv_dims_and_data():
         assert y_logp.type.shape == (None,)

     frozen_m = freeze_dims_and_data(m)
-    std, x, y = frozen_m["std"], frozen_m["x"], frozen_m["y"]
+    data, x, y = frozen_m["test_data"], frozen_m["x"], frozen_m["y"]
     x_logp, y_logp = frozen_m.logp(sum=False)
-    assert isinstance(std, Constant)
+    assert isinstance(data, Constant)
     assert x.type.shape == (5,)
     assert y.type.shape == (6,)
     assert x_logp.type.shape == (5,)
     assert y_logp.type.shape == (6,)

+    # Test that trying to update frozen data or dims raises an informative error
+    with frozen_m:
+        with pytest.raises(TypeError, match="The variable `test_data` must be a `SharedVariable`"):
+            frozen_m.set_data("test_data", values=[2])
+        with pytest.raises(
+            TypeError, match="The dim_length of `test_dim` must be a `SharedVariable`"
+        ):
+            frozen_m.set_dim("test_dim", new_length=6, coord_values=range(6))
+
+    # Test that the original model can still be updated
+    with m:
+        m.set_data("test_data", values=[2])
+        m.set_dim("test_dim", new_length=6, coord_values=range(6))
+    assert m["test_data"].get_value() == [2]
+    assert m.dim_lengths["test_dim"].get_value() == 6

-def test_freeze_rv_dims_nothing_to_change():
+
+def test_freeze_dims_nothing_to_change():
     with Model(coords={"test_dim": range(5)}) as m:
         x = HalfNormal("x", shape=(5,))
         y = Normal("y", shape=x.shape[0] + 1)

     assert m.point_logps() == freeze_dims_and_data(m).point_logps()
+
+
+def test_freeze_dims_and_data_subset():
+    with Model(coords={"dim1": range(3), "dim2": range(5)}) as m:
+        data1 = Data("data1", [1, 2, 3], dims="dim1")
+        data2 = Data("data2", [1, 2, 3, 4, 5], dims="dim2")
+        var1 = Normal("var1", dims="dim1")
+        var2 = Normal("var2", dims="dim2")
+        x = data1 * var1
+        y = data2 * var2
+        det = Deterministic("det", x[:, None] + y[None, :])
+
+    assert det.type.shape == (None, None)
+
+    new_m = freeze_dims_and_data(m, dims=["dim1"], data=[])
+    assert new_m["det"].type.shape == (3, None)
+    assert isinstance(new_m.dim_lengths["dim1"], Constant) and new_m.dim_lengths["dim1"].data == 3
+    assert isinstance(new_m.dim_lengths["dim2"], SharedVariable)
+    assert isinstance(new_m["data1"], SharedVariable)
+    assert isinstance(new_m["data2"], SharedVariable)
+
+    new_m = freeze_dims_and_data(m, dims=["dim2"], data=[])
+    assert new_m["det"].type.shape == (None, 5)
+    assert isinstance(new_m.dim_lengths["dim1"], SharedVariable)
+    assert isinstance(new_m.dim_lengths["dim2"], Constant) and new_m.dim_lengths["dim2"].data == 5
+    assert isinstance(new_m["data1"], SharedVariable)
+    assert isinstance(new_m["data2"], SharedVariable)
+
+    new_m = freeze_dims_and_data(m, dims=["dim1", "dim2"], data=[])
+    assert new_m["det"].type.shape == (3, 5)
+    assert isinstance(new_m.dim_lengths["dim1"], Constant) and new_m.dim_lengths["dim1"].data == 3
+    assert isinstance(new_m.dim_lengths["dim2"], Constant) and new_m.dim_lengths["dim2"].data == 5
+    assert isinstance(new_m["data1"], SharedVariable)
+    assert isinstance(new_m["data2"], SharedVariable)
isinstance(new_m["data2"], SharedVariable) + + new_m = freeze_dims_and_data(m, dims=[], data=["data1"]) + assert new_m["det"].type.shape == (3, None) + assert isinstance(new_m.dim_lengths["dim1"], SharedVariable) + assert isinstance(new_m.dim_lengths["dim2"], SharedVariable) + assert isinstance(new_m["data1"], Constant) and np.all(new_m["data1"].data == [1, 2, 3]) + assert isinstance(new_m["data2"], SharedVariable) + + new_m = freeze_dims_and_data(m, dims=[], data=["data2"]) + assert new_m["det"].type.shape == (None, 5) + assert isinstance(new_m.dim_lengths["dim1"], SharedVariable) + assert isinstance(new_m.dim_lengths["dim2"], SharedVariable) + assert isinstance(new_m["data1"], SharedVariable) + assert isinstance(new_m["data2"], Constant) and np.all(new_m["data2"].data == [1, 2, 3, 4, 5]) + + new_m = freeze_dims_and_data(m, dims=[], data=["data1", "data2"]) + assert new_m["det"].type.shape == (3, 5) + assert isinstance(new_m.dim_lengths["dim1"], SharedVariable) + assert isinstance(new_m.dim_lengths["dim2"], SharedVariable) + assert isinstance(new_m["data1"], Constant) and np.all(new_m["data1"].data == [1, 2, 3]) + assert isinstance(new_m["data2"], Constant) and np.all(new_m["data2"].data == [1, 2, 3, 4, 5]) + + new_m = freeze_dims_and_data(m, dims=["dim1"], data=["data2"]) + assert new_m["det"].type.shape == (3, 5) + assert isinstance(new_m.dim_lengths["dim1"], Constant) and new_m.dim_lengths["dim1"].data == 3 + assert isinstance(new_m.dim_lengths["dim2"], SharedVariable) + assert isinstance(new_m["data1"], SharedVariable) + assert isinstance(new_m["data2"], Constant) and np.all(new_m["data2"].data == [1, 2, 3, 4, 5])