diff --git a/pymc_extras/statespace/core/statespace.py b/pymc_extras/statespace/core/statespace.py index 12a64c0f..4c9651d7 100644 --- a/pymc_extras/statespace/core/statespace.py +++ b/pymc_extras/statespace/core/statespace.py @@ -34,8 +34,8 @@ FILTER_OUTPUT_DIMS, FILTER_OUTPUT_TYPES, JITTER_DEFAULT, - LONG_MATRIX_NAMES, MATRIX_DIMS, + MATRIX_NAMES, OBS_STATE_DIM, SHOCK_DIM, SHORT_NAME_TO_LONG, @@ -750,7 +750,7 @@ def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]: matrices = self.unpack_statespace() registered_matrices = [] - for i, (matrix, name) in enumerate(zip(matrices, LONG_MATRIX_NAMES)): + for i, (matrix, name) in enumerate(zip(matrices, MATRIX_NAMES)): time_varying_ndim = 2 if name in VECTOR_VALUED else 3 if not getattr(pm_mod, name, None): shape, dims = self._get_matrix_shape_and_dims(name) @@ -1471,7 +1471,7 @@ def sample_statespace_matrices( _verify_group(group) if matrix_names is None: - matrix_names = LONG_MATRIX_NAMES + matrix_names = MATRIX_NAMES elif isinstance(matrix_names, str): matrix_names = [matrix_names] @@ -1484,7 +1484,7 @@ def sample_statespace_matrices( self._insert_data_variables() matrices = self.unpack_statespace() - for short_name, matrix in zip(LONG_MATRIX_NAMES, matrices): + for short_name, matrix in zip(MATRIX_NAMES, matrices): long_name = SHORT_NAME_TO_LONG[short_name] if (long_name in matrix_names) or (short_name in matrix_names): name = long_name if long_name in matrix_names else short_name @@ -2038,10 +2038,7 @@ def forecast( } matrices = graph_replace(matrices, replace=sub_dict, strict=True) - [ - setattr(matrix, "name", name) - for name, matrix in zip(LONG_MATRIX_NAMES[2:], matrices) - ] + [setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)] _ = LinearGaussianStateSpace( "forecast",