Skip to content

Commit

Permalink
Revert questionable changes
Browse files Browse the repository at this point in the history
  • Loading branch information
fonnesbeck committed Dec 11, 2024
1 parent 2f245d7 commit a8a8a6b
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions pymc_extras/statespace/core/statespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
FILTER_OUTPUT_DIMS,
FILTER_OUTPUT_TYPES,
JITTER_DEFAULT,
LONG_MATRIX_NAMES,
MATRIX_DIMS,
MATRIX_NAMES,
OBS_STATE_DIM,
SHOCK_DIM,
SHORT_NAME_TO_LONG,
Expand Down Expand Up @@ -750,7 +750,7 @@ def _register_matrices_with_pymc_model(self) -> list[pt.TensorVariable]:
matrices = self.unpack_statespace()

registered_matrices = []
for i, (matrix, name) in enumerate(zip(matrices, LONG_MATRIX_NAMES)):
for i, (matrix, name) in enumerate(zip(matrices, MATRIX_NAMES)):
time_varying_ndim = 2 if name in VECTOR_VALUED else 3
if not getattr(pm_mod, name, None):
shape, dims = self._get_matrix_shape_and_dims(name)
Expand Down Expand Up @@ -1471,7 +1471,7 @@ def sample_statespace_matrices(
_verify_group(group)

if matrix_names is None:
matrix_names = LONG_MATRIX_NAMES
matrix_names = MATRIX_NAMES
elif isinstance(matrix_names, str):
matrix_names = [matrix_names]

Expand All @@ -1484,7 +1484,7 @@ def sample_statespace_matrices(

self._insert_data_variables()
matrices = self.unpack_statespace()
for short_name, matrix in zip(LONG_MATRIX_NAMES, matrices):
for short_name, matrix in zip(MATRIX_NAMES, matrices):
long_name = SHORT_NAME_TO_LONG[short_name]
if (long_name in matrix_names) or (short_name in matrix_names):
name = long_name if long_name in matrix_names else short_name
Expand Down Expand Up @@ -2038,10 +2038,7 @@ def forecast(
}

matrices = graph_replace(matrices, replace=sub_dict, strict=True)
[
setattr(matrix, "name", name)
for name, matrix in zip(LONG_MATRIX_NAMES[2:], matrices)
]
[setattr(matrix, "name", name) for name, matrix in zip(MATRIX_NAMES[2:], matrices)]

_ = LinearGaussianStateSpace(
"forecast",
Expand Down

0 comments on commit a8a8a6b

Please sign in to comment.