JOSS review - Remarks on code #377

Open
gdalle opened this issue Sep 12, 2024 · 3 comments


gdalle commented Sep 12, 2024

Hi and congrats on the package!

I'm one of the reviewers for the JOSS paper you submitted, so here I'll list my questions and concerns about the code itself. I will update this issue as my reading progresses, so it may be best not to start answering right away.

Warning

I am not a Python developer; my main language is Julia, so I won't be able to give much advice beyond what I observed when trying out the examples. I also don't know the best practices for package and version management: I just use mamba, and then pip inside the mamba environment.


Design

Execution

Tests

I followed the testing workflow on the home page of the documentation and it yielded one failure.

============================== test session starts ===============================
platform linux -- Python 3.12.5, pytest-8.3.3, pluggy-1.5.0
rootdir: /home/gdalle/Work/Review/dynamax
configfile: pyproject.toml
plugins: cov-5.0.0, jaxtyping-0.2.34, typeguard-2.13.3, anyio-4.4.0
collected 83 items                                                                                                                                                          

dynamax/generalized_gaussian_ssm/inference_test.py .                                                                                                                  [  1%]
dynamax/generalized_gaussian_ssm/models_test.py ...                                                                                                                   [  4%]
dynamax/hidden_markov_model/inference_test.py ...........                                                                                                             [ 18%]
dynamax/hidden_markov_model/models/test_models.py ...........F......                                                                     [ 39%]
dynamax/linear_gaussian_ssm/inference_test.py ...                                                                                        [ 43%]
dynamax/linear_gaussian_ssm/info_inference_test.py .......                                                                               [ 51%]
dynamax/linear_gaussian_ssm/models_test.py ..                                                                                            [ 54%]
dynamax/linear_gaussian_ssm/parallel_inference_test.py ............                                                                      [ 68%]
dynamax/nonlinear_gaussian_ssm/inference_ekf_test.py .....                                                                               [ 74%]
dynamax/nonlinear_gaussian_ssm/inference_ukf_test.py .                                                                                   [ 75%]
dynamax/parameters_test.py ....                                                                                                          [ 80%]
dynamax/slds/inference_test.py ..                                                                                                        [ 83%]
dynamax/utils/distributions_test.py ..........                                                                                           [ 95%]
dynamax/utils/utils_test.py ....                                                                                                         [100%]

=================================================================== FAILURES ===================================================================
__________________________________________ test_sample_and_fit[LinearRegressionHMM-kwargs11-inputs11] __________________________________________

cls = <class 'dynamax.hidden_markov_model.models.linreg_hmm.LinearRegressionHMM'>, kwargs = {'emission_dim': 3, 'input_dim': 5, 'num_states': 4}
inputs = Array([[1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
  ...1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1.]], dtype=float32)

    @pytest.mark.parametrize(["cls", "kwargs", "inputs"], CONFIGS)
    def test_sample_and_fit(cls, kwargs, inputs):
        hmm = cls(**kwargs)
        #key1, key2 = jr.split(jr.PRNGKey(int(datetime.now().timestamp())))
        key1, key2 = jr.split(jr.PRNGKey(42))
        params, param_props = hmm.initialize(key1)
        states, emissions = hmm.sample(params, key2, num_timesteps=NUM_TIMESTEPS, inputs=inputs)
        fitted_params, lps = hmm.fit_em(params, param_props, emissions, inputs=inputs, num_iters=10)
>       assert monotonically_increasing(lps, atol=1e-2, rtol=1e-2)
E       assert Array(False, dtype=bool)
E        +  where Array(False, dtype=bool) = monotonically_increasing(Array([-224.90086,        nan,        nan,        nan,        nan,\n              nan,        nan,        nan,        nan,        nan],      dtype=float32), atol=0.01, rtol=0.01)

dynamax/hidden_markov_model/models/test_models.py:39: AssertionError
------------------------------------------------------------- Captured stdout call -------------------------------------------------------------
 |████████████████████████████████████████| 100.00% [10/10 00:00<00:00]
=============================================================== warnings summary ===============================================================
dynamax/generalized_gaussian_ssm/inference.py:86
  /home/gdalle/Work/Review/dynamax/dynamax/generalized_gaussian_ssm/inference.py:86: SyntaxWarning: invalid escape sequence '\i'
    """Predict next mean and covariance under an additive-noise Gaussian filter

dynamax/hidden_markov_model/models/abstractions.py:488
  /home/gdalle/Work/Review/dynamax/dynamax/hidden_markov_model/models/abstractions.py:488: SyntaxWarning: invalid escape sequence '\m'
    """Abstract base class of Hidden Markov Models (HMMs).

dynamax/hidden_markov_model/models/gaussian_hmm.py:722
  /home/gdalle/Work/Review/dynamax/dynamax/hidden_markov_model/models/gaussian_hmm.py:722: SyntaxWarning: invalid escape sequence '\s'
    """Initialize the model parameters and their corresponding properties.

dynamax/hidden_markov_model/parallel_inference.py:15
  /home/gdalle/Work/Review/dynamax/dynamax/hidden_markov_model/parallel_inference.py:15: SyntaxWarning: invalid escape sequence '\m'
    """Filtering associative scan elements.

dynamax/hidden_markov_model/parallel_inference.py:135
  /home/gdalle/Work/Review/dynamax/dynamax/hidden_markov_model/parallel_inference.py:135: SyntaxWarning: invalid escape sequence '\m'
    """Associative scan elements $E_ij$ are vectors specifying a sample::

dynamax/generalized_gaussian_ssm/models_test.py::test_poisson_emission[key0-kwargs0]
dynamax/generalized_gaussian_ssm/models_test.py::test_poisson_emission[key0-kwargs0]
dynamax/generalized_gaussian_ssm/models_test.py::test_poisson_emission[key1-kwargs1]
dynamax/generalized_gaussian_ssm/models_test.py::test_poisson_emission[key1-kwargs1]
dynamax/generalized_gaussian_ssm/models_test.py::test_poisson_emission[key2-kwargs2]
dynamax/generalized_gaussian_ssm/models_test.py::test_poisson_emission[key2-kwargs2]
dynamax/hidden_markov_model/models/test_models.py::test_sample_and_fit[PoissonHMM-kwargs14-None]
dynamax/hidden_markov_model/models/test_models.py::test_sample_and_fit[PoissonHMM-kwargs14-None]
  /home/gdalle/miniforge3/envs/dynamax-test/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:118: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in astype is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    return lax_numpy.astype(self, dtype, copy=copy, device=device)

dynamax/generalized_gaussian_ssm/models_test.py: 12 warnings
dynamax/hidden_markov_model/models/test_models.py: 4 warnings
  /home/gdalle/miniforge3/envs/dynamax-test/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/ops.py:339: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in array is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    return np.array(value, dtype=dtype)

dynamax/generalized_gaussian_ssm/models_test.py: 12 warnings
dynamax/hidden_markov_model/models/test_models.py: 4 warnings
  /home/gdalle/miniforge3/envs/dynamax-test/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/random_generators.py:290: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    minval = minval + np.zeros([1] * final_rank, dtype=dtype)

dynamax/generalized_gaussian_ssm/models_test.py: 12 warnings
dynamax/hidden_markov_model/models/test_models.py: 4 warnings
  /home/gdalle/miniforge3/envs/dynamax-test/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/random_generators.py:291: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in zeros is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    maxval = maxval + np.zeros([1] * final_rank, dtype=dtype)

dynamax/generalized_gaussian_ssm/models_test.py: 12 warnings
dynamax/hidden_markov_model/models/test_models.py: 4 warnings
  /home/gdalle/miniforge3/envs/dynamax-test/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/random_generators.py:292: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'>  is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    return jaxrand.uniform(key=seed, shape=shape, dtype=dtype, minval=minval,

dynamax/generalized_gaussian_ssm/models_test.py::test_poisson_emission[key0-kwargs0]
dynamax/generalized_gaussian_ssm/models_test.py::test_poisson_emission[key0-kwargs0]
dynamax/generalized_gaussian_ssm/models_test.py::test_poisson_emission[key1-kwargs1]
dynamax/generalized_gaussian_ssm/models_test.py::test_poisson_emission[key1-kwargs1]
dynamax/generalized_gaussian_ssm/models_test.py::test_poisson_emission[key2-kwargs2]
dynamax/generalized_gaussian_ssm/models_test.py::test_poisson_emission[key2-kwargs2]
dynamax/hidden_markov_model/models/test_models.py::test_sample_and_fit[PoissonHMM-kwargs14-None]
dynamax/hidden_markov_model/models/test_models.py::test_sample_and_fit[PoissonHMM-kwargs14-None]
  /home/gdalle/miniforge3/envs/dynamax-test/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/numpy_array.py:450: UserWarning: Explicitly requested dtype <class 'jax.numpy.float64'> requested in ones is not available, and will be truncated to dtype float32. To enable more dtypes, set the jax_enable_x64 configuration option or the JAX_ENABLE_X64 shell environment variable. See https://github.com/google/jax#current-gotchas for more.
    lambda shape, dtype=np.float32, name=None, layout=None: np.ones(  # pylint: disable=g-long-lambda

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================================================== short test summary info ============================================================
FAILED dynamax/hidden_markov_model/models/test_models.py::test_sample_and_fit[LinearRegressionHMM-kwargs11-inputs11] - assert Array(False, dtype=bool)
============================================ 1 failed, 82 passed, 85 warnings in 149.15s (0:02:29) =============================================
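The `nan` values in the log above are enough to explain the failed assertion: any ordered comparison involving NaN is false, so a tolerance-based monotonicity check cannot pass once the log-probabilities turn into NaN. Below is a minimal pure-Python sketch of that behavior. It assumes a check of the form `diff >= -(atol + rtol * |previous|)`; the helper names are hypothetical stand-ins for illustration, not dynamax's own `monotonically_increasing`.

```python
import math


def is_monotonically_increasing(values, atol=0.0, rtol=0.0):
    """Hypothetical stand-in for a tolerance-based monotonicity check."""
    for prev, curr in zip(values, values[1:]):
        threshold = atol + rtol * abs(prev)
        # Any comparison involving NaN evaluates to False,
        # so a NaN on either side makes this step fail.
        if not (curr - prev >= -threshold):
            return False
    return True


def first_nan_index(values):
    """Return the index of the first NaN, or None if there is none."""
    for i, v in enumerate(values):
        if math.isnan(v):
            return i
    return None


# Log-probabilities as reported in the failing test: one finite value, then NaN.
lps = [-224.90086] + [math.nan] * 9
print(is_monotonically_increasing(lps, atol=1e-2, rtol=1e-2))  # False
print(first_nan_index(lps))  # 1
```

Locating the first NaN this way suggests the problem already appears in the first EM iteration for `LinearRegressionHMM`, rather than being a tolerance issue in the test.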

First tutorial

When I try to run the tutorial notebook after a naive installation, it fails to generate sample data because of a breaking change in NumPy 2.0. I'm not sure why pip allowed me to install NumPy 2.0 as a dependency, since it doesn't seem to be supported yet (#366).
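To see which NumPy major version pip resolved in a given environment, a small standard-library check can help. This is only a sketch: the helper names are hypothetical, and the threshold of 2 is taken from the observation in this report rather than from any documented dynamax constraint.

```python
from importlib import metadata


def major_version(version: str) -> int:
    """Extract the leading major component from a version string like '2.0.1'."""
    return int(version.split(".")[0])


def numpy_major_or_none():
    """Return the installed NumPy major version, or None if NumPy is not installed."""
    try:
        return major_version(metadata.version("numpy"))
    except metadata.PackageNotFoundError:
        return None


installed = numpy_major_or_none()
if installed is None:
    print("NumPy is not installed in this environment")
elif installed >= 2:
    # Assumption: NumPy 2.x is the trigger, as observed in this report.
    print(f"NumPy {installed}.x is installed; the tutorial may hit the error reported here")
else:
    print(f"NumPy {installed}.x is installed")
```

If the check reports 2 or higher, installing with a constraint such as `numpy<2` is one possible workaround, though it is untested here.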

AttributeError                            Traceback (most recent call last)
Cell In[5], line 22
     15 emission_means = jnp.column_stack([
     16     jnp.cos(jnp.linspace(0, 2 * jnp.pi, true_num_states + 1))[:-1],
     17     jnp.sin(jnp.linspace(0, 2 * jnp.pi, true_num_states + 1))[:-1],
     18     jnp.zeros((true_num_states, emission_dim - 2)),
     19     ])
     20 emission_covs = jnp.tile(0.1**2 * jnp.eye(emission_dim), (true_num_states, 1, 1))
---> 22 true_params, _ = hmm.initialize(initial_probs=initial_probs,
     23                                 transition_matrix=transition_matrix,
     24                                 emission_means=emission_means,
     25                                 emission_covariances=emission_covs)
     27 # Sample train, validation, and test data
     28 train_key, val_key, test_key = jr.split(jr.PRNGKey(0), 3)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/dynamax/hidden_markov_model/models/gaussian_hmm.py:651, in GaussianHMM.initialize(self, key, method, initial_probs, transition_matrix, emission_means, emission_covariances, emissions)
    649 params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
    650 params["transitions"], props["transitions"] = self.transition_component.initialize(key2, method=method, transition_matrix=transition_matrix)
--> 651 params["emissions"], props["emissions"] = self.emission_component.initialize(key3, method=method, emission_means=emission_means, emission_covariances=emission_covariances, emissions=emissions)
    652 return ParamsGaussianHMM(**params), ParamsGaussianHMM(**props)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/dynamax/hidden_markov_model/models/gaussian_hmm.py:83, in GaussianHMMEmissions.initialize(self, key, method, emission_means, emission_covariances, emissions)
     81 elif method.lower() == "prior":
     82     this_key, key = jr.split(key)
---> 83     prior = NormalInverseWishart(self.emission_prior_mean, self.emission_prior_conc,
     84                                  self.emission_prior_df, self.emission_prior_scale)
     85     (_emission_covs, _emission_means) = prior.sample(seed=this_key, sample_shape=(self.num_states,))
     87 else:

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/dynamax/utils/distributions.py:133, in NormalInverseWishart.__init__(self, loc, mean_concentration, df, scale)
    129 self._df = df
    130 self._scale = scale
    132 super(NormalInverseWishart, self).__init__([
--> 133     InverseWishart(df, scale),
    134     lambda Sigma: tfd.MultivariateNormalFullCovariance(loc, Sigma / mean_concentration)
    135 ])
    137 self._parameters = dict(loc=loc, mean_concentration=mean_concentration, df=df, scale=scale)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/dynamax/utils/distributions.py:51, in InverseWishart.__init__(self, df, scale)
     48 cho_scale = jnp.linalg.cholesky(scale)
     49 inv_scale_tril = solve_triangular(cho_scale, eye, lower=True)
---> 51 super().__init__(
     52     tfd.WishartTriL(df, scale_tril=inv_scale_tril),
     53     tfb.Chain([tfb.CholeskyOuterProduct(),
     54                tfb.CholeskyToInvCholesky(),
     55                tfb.Invert(tfb.CholeskyOuterProduct())]))
     57 self._parameters = dict(df=df, scale=scale)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/decorator.py:232, in decorate.<locals>.fun(*args, **kw)
    230 if not kwsyntax:
    231     args, kw = fix(args, kw, sig)
--> 232 return caller(func, *(extras + args), **kw)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py:342, in _DistributionMeta.__new__.<locals>.wrapped_init(***failed resolving arguments***)
    339 # Note: if we ever want to have things set in `self` before `__init__` is
    340 # called, here is the place to do it.
    341 self_._parameters = None
--> 342 default_init(self_, *args, **kwargs)
    343 # Note: if we ever want to override things set in `self` by subclass
    344 # `__init__`, here is the place to do it.
    345 if self_._parameters is None:
    346   # We prefer subclasses will set `parameters = dict(locals())` because
    347   # this has nearly zero overhead. However, failing to do this, we will
    348   # resolve the input arguments dynamically and only when needed.

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py:244, in _TransformedDistribution.__init__(self, distribution, bijector, kwargs_split_fn, validate_args, parameters, name)
    238 self._zero = tf.constant(0, dtype=tf.int32, name='zero')
    240 # We don't just want to check isinstance(JointDistribution) because
    241 # TransformedDistributions with multipart bijectors are effectively
    242 # joint but don't inherit from JD. The 'duck-type' test is that
    243 # JDs have a structured dtype.
--> 244 dtype = self.bijector.forward_dtype(self.distribution.dtype)
    245 self._is_joint = tf.nest.is_nested(dtype)
    247 super(_TransformedDistribution, self).__init__(
    248     dtype=dtype,
    249     reparameterization_type=self._distribution.reparameterization_type,
   (...)
    252     parameters=parameters,
    253     name=name)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py:1705, in Bijector.forward_dtype(self, dtype, name, **kwargs)
   1701   input_dtype = nest_util.broadcast_structure(
   1702       self.forward_min_event_ndims, self.dtype)
   1703 else:
   1704   # Make sure inputs are compatible with statically-known dtype.
-> 1705   input_dtype = nest.map_structure_up_to(
   1706       self.forward_min_event_ndims,
   1707       lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
   1708       nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
   1709       check_types=False)
   1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
   1712 try:
   1713   # kwargs may alter dtypes themselves, but we currently require
   1714   # structure to be statically known.

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:324, in map_structure_up_to(shallow_structure, func, *structures, **kwargs)
    323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
--> 324   return map_structure_with_tuple_paths_up_to(
    325       shallow_structure,
    326       lambda _, *args: func(*args),  # Discards path.
    327       *structures,
    328       **kwargs)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:353, in map_structure_with_tuple_paths_up_to(shallow_structure, func, expand_composites, *structures, **kwargs)
    350 for input_tree in structures:
    351   assert_shallow_structure(
    352       shallow_structure, input_tree, check_types=check_types)
--> 353 return dm_tree.map_structure_with_path_up_to(
    354     shallow_structure, func, *structures, **kwargs)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/tree/__init__.py:771, in map_structure_with_path_up_to(***failed resolving arguments***)
    769 results = []
    770 for path_and_values in _multiyield_flat_up_to(shallow_structure, *structures):
--> 771   results.append(func(*path_and_values))
    772 return unflatten_as(shallow_structure, results)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py:326, in map_structure_up_to.<locals>.<lambda>(_, *args)
    323 def map_structure_up_to(shallow_structure, func, *structures, **kwargs):
    324   return map_structure_with_tuple_paths_up_to(
    325       shallow_structure,
--> 326       lambda _, *args: func(*args),  # Discards path.
    327       *structures,
    328       **kwargs)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py:1707, in Bijector.forward_dtype.<locals>.<lambda>(x)
   1701   input_dtype = nest_util.broadcast_structure(
   1702       self.forward_min_event_ndims, self.dtype)
   1703 else:
   1704   # Make sure inputs are compatible with statically-known dtype.
   1705   input_dtype = nest.map_structure_up_to(
   1706       self.forward_min_event_ndims,
-> 1707       lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
   1708       nest_util.coerce_structure(self.forward_min_event_ndims, dtype),
   1709       check_types=False)
   1711 output_dtype = self._forward_dtype(input_dtype, **kwargs)
   1712 try:
   1713   # kwargs may alter dtypes themselves, but we currently require
   1714   # structure to be statically known.

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/internal/dtype_util.py:247, in convert_to_dtype(tensor_or_dtype, dtype, dtype_hint)
    245 elif isinstance(tensor_or_dtype, np.ndarray):
    246   dt = base_dtype(dtype or dtype_hint or tensor_or_dtype.dtype)
--> 247 elif np.issctype(tensor_or_dtype):
    248   dt = base_dtype(dtype or dtype_hint or tensor_or_dtype)
    249 else:
    250   # If this is a Python object, call `convert_to_tensor` and grab the dtype.
    251   # Note that this will add ops in graph-mode; we may want to consider
    252   # other ways to handle this case.

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/__init__.py:400, in __getattr__(attr)
    397     raise AttributeError(__former_attrs__[attr], name=None)
    399 if attr in __expired_attributes__:
--> 400     raise AttributeError(
    401         f"`np.{attr}` was removed in the NumPy 2.0 release. "
    402         f"{__expired_attributes__[attr]}",
    403         name=None
    404     )
    406 if attr == "chararray":
    407     warnings.warn(
    408         "`np.chararray` is deprecated and will be removed from "
    409         "the main namespace in the future. Use an array with a string "
    410         "or bytes dtype instead.", DeprecationWarning, stacklevel=2)

AttributeError: `np.issctype` was removed in the NumPy 2.0 release. Use `issubclass(rep, np.generic)` instead.
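
For reference, here is a minimal sketch of the replacement that the error message hints at. It only illustrates the `issubclass(rep, np.generic)` suggestion and is not the fix applied in `tensorflow_probability`. The helper name `is_scalar_type` is hypothetical, and it probably does not cover every input the removed `np.issctype` accepted (for instance `np.dtype` instances).

```python
import numpy as np


def is_scalar_type(rep) -> bool:
    """Rough stand-in for the removed np.issctype (hypothetical helper).

    Follows the hint in the NumPy 2.0 error message. The isinstance guard
    is needed because issubclass raises TypeError on non-class arguments.
    """
    return isinstance(rep, type) and issubclass(rep, np.generic)


# The major version tells which branch of the NumPy 1.x / 2.x split is installed.
numpy_major = int(np.__version__.split(".")[0])
print("NumPy major version:", numpy_major)
print("np.float32 is a scalar type:", is_scalar_type(np.float32))
print("1.0 is a scalar type:", is_scalar_type(1.0))
```

The helper relies only on `np.generic`, which exists in both NumPy 1.x and 2.x, so the same check should run on either version.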

When I manually ensure that NumPy 1.26.4 is used instead, I get an error even earlier, as soon as I try to import dynamax:

RecursionError                            Traceback (most recent call last)
Cell In[1], line 2
      1 try:
----> 2     import dynamax
      3 except ModuleNotFoundError:
      4     print('installing dynamax')

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/dynamax/__init__.py:8
      5 import dynamax.warnings
      7 # Default to float32 matrix multiplication on TPUs and GPUs
----> 8 import jax
      9 jax.config.update('jax_default_matmul_precision', 'float32')

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/jax/__init__.py:25
     22 from jax.version import __version_info__ as __version_info__
     24 # Set Cloud TPU env vars if necessary before transitively loading C++ backend
---> 25 from jax._src.cloud_tpu_init import cloud_tpu_init as _cloud_tpu_init
     26 try:
     27   _cloud_tpu_init()

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/jax/_src/cloud_tpu_init.py:17
     15 import os
     16 from jax import version
---> 17 from jax._src import config
     18 from jax._src import hardware_utils
     20 running_in_cloud_tpu_vm: bool = False

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/jax/_src/config.py:27
     24 import threading
     25 from typing import Any, Generic, NamedTuple, NoReturn, Protocol, TypeVar, cast
---> 27 from jax._src import lib
     28 from jax._src.lib import jax_jit
     29 from jax._src.lib import transfer_guard_lib

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/jax/_src/lib/__init__.py:87
     84 cpu_feature_guard.check_cpu_features()
     86 import jaxlib.utils as utils
---> 87 import jaxlib.xla_client as xla_client
     88 import jaxlib.lapack as lapack
     90 xla_extension = xla_client._xla

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/jaxlib/xla_client.py:30
     27 import threading
     28 from typing import Any, Protocol, Union
---> 30 import ml_dtypes
     31 import numpy as np
     33 from . import xla_extension as _xla

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/ml_dtypes/__init__.py:32
     16 __all__ = [
     17     "__version__",
     18     "bfloat16",
   (...)
     27     "uint4",
     28 ]
     30 from typing import Type
---> 32 from ml_dtypes._finfo import finfo
     33 from ml_dtypes._iinfo import iinfo
     34 from ml_dtypes._ml_dtypes_ext import bfloat16

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/ml_dtypes/_finfo.py:27
     24 from ml_dtypes._ml_dtypes_ext import float8_e5m2fnuz
     25 import numpy as np
---> 27 _bfloat16_dtype = np.dtype(bfloat16)
     28 _float8_e4m3b11fnuz_dtype = np.dtype(float8_e4m3b11fnuz)
     29 _float8_e4m3fn_dtype = np.dtype(float8_e4m3fn)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:46, in __repr__(dtype)
     45 def __repr__(dtype):
---> 46     arg_str = _construction_repr(dtype, include_align=False)
     47     if dtype.isalignedstruct:
     48         arg_str = arg_str + ", align=True"

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:100, in _construction_repr(dtype, include_align, short)
     98     return _subarray_str(dtype)
     99 else:
--> 100     return _scalar_str(dtype, short=short)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:143, in _scalar_str(dtype, short)
    140 elif dtype.type == np.timedelta64:
    141     return "'%sm8%s'" % (byteorder, _datetime_metadata_str(dtype))
--> 143 elif np.issubdtype(dtype, np.number):
    144     # Short repr with endianness, like '<f8'
    145     if short or dtype.byteorder not in ('=', '|'):
    146         return "'%s%c%d'" % (byteorder, dtype.kind, dtype.itemsize)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:417, in issubdtype(arg1, arg2)
    [359](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:359) r"""
    [360](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:360) Returns True if first argument is a typecode lower/equal in type hierarchy.
    [361](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:361) 
   (...)
    [414](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:414) 
    [415](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:415) """
    [416](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:416) if not issubclass_(arg1, generic):
--> [417](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:417)     arg1 = dtype(arg1).type
    [418](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:418) if not issubclass_(arg2, generic):
    [419](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:419)     arg2 = dtype(arg2).type

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:46, in __repr__(dtype)
     [45](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:45) def __repr__(dtype):
---> [46](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:46)     arg_str = _construction_repr(dtype, include_align=False)
     [47](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:47)     if dtype.isalignedstruct:
     [48](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:48)         arg_str = arg_str + ", align=True"

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:100, in _construction_repr(dtype, include_align, short)
     [98](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:98)     return _subarray_str(dtype)
     [99](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:99) else:
--> [100](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:100)     return _scalar_str(dtype, short=short)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:143, in _scalar_str(dtype, short)
    [140](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:140) elif dtype.type == np.timedelta64:
    [141](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:141)     return "'%sm8%s'" % (byteorder, _datetime_metadata_str(dtype))
--> [143](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:143) elif np.issubdtype(dtype, np.number):
    [144](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:144)     # Short repr with endianness, like '<f8'
    [145](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:145)     if short or dtype.byteorder not in ('=', '|'):
    [146](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:146)         return "'%s%c%d'" % (byteorder, dtype.kind, dtype.itemsize)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:417, in issubdtype(arg1, arg2)
    [359](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:359) r"""
    [360](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:360) Returns True if first argument is a typecode lower/equal in type hierarchy.
    [361](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:361) 
   (...)
    [414](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:414) 
    [415](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:415) """
    [416](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:416) if not issubclass_(arg1, generic):
--> [417](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:417)     arg1 = dtype(arg1).type
    [418](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:418) if not issubclass_(arg2, generic):
    [419](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:419)     arg2 = dtype(arg2).type

    [... skipping similar frames: __repr__ at line 46 (726 times), _construction_repr at line 100 (726 times), _scalar_str at line 143 (725 times), issubdtype at line 417 (725 times)]

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:143, in _scalar_str(dtype, short)
    [140](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:140) elif dtype.type == np.timedelta64:
    [141](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:141)     return "'%sm8%s'" % (byteorder, _datetime_metadata_str(dtype))
--> [143](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:143) elif np.issubdtype(dtype, np.number):
    [144](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:144)     # Short repr with endianness, like '<f8'
    [145](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:145)     if short or dtype.byteorder not in ('=', '|'):
    [146](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:146)         return "'%s%c%d'" % (byteorder, dtype.kind, dtype.itemsize)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:417, in issubdtype(arg1, arg2)
    [359](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:359) r"""
    [360](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:360) Returns True if first argument is a typecode lower/equal in type hierarchy.
    [361](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:361) 
   (...)
    [414](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:414) 
    [415](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:415) """
    [416](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:416) if not issubclass_(arg1, generic):
--> [417](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:417)     arg1 = dtype(arg1).type
    [418](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:418) if not issubclass_(arg2, generic):
    [419](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/numerictypes.py:419)     arg2 = dtype(arg2).type

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:46, in __repr__(dtype)
     [45](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:45) def __repr__(dtype):
---> [46](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:46)     arg_str = _construction_repr(dtype, include_align=False)
     [47](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:47)     if dtype.isalignedstruct:
     [48](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:48)         arg_str = arg_str + ", align=True"

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:100, in _construction_repr(dtype, include_align, short)
     [98](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:98)     return _subarray_str(dtype)
     [99](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:99) else:
--> [100](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:100)     return _scalar_str(dtype, short=short)

File ~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:104, in _scalar_str(dtype, short)
    [103](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:103) def _scalar_str(dtype, short):
--> [104](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:104)     byteorder = _byte_order_str(dtype)
    [106](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:106)     if dtype.type == np.bool_:
    [107](https://file+.vscode-resource.vscode-cdn.net/home/gdalle/Work/Review/~/miniforge3/envs/dynamax/lib/python3.12/site-packages/numpy/core/_dtype.py:107)         if short:

RecursionError: maximum recursion depth exceeded

Related issues:

@gdalle
Author

gdalle commented Nov 1, 2024

I tried again to run the code on the documentation home page and got the following error:

import jax.numpy as jnp
import jax.random as jr
from dynamax.hidden_markov_model import GaussianHMM

key1, key2, key3 = jr.split(jr.PRNGKey(0), 3)
num_states = 3
emission_dim = 2
num_timesteps = 1000

# Make a Gaussian HMM and sample data from it
hmm = GaussianHMM(num_states, emission_dim)
true_params, _ = hmm.initialize(key1)

Traceback (most recent call last):
  File "/home/gdalle/Work/GitHub/Python/dynamax-test/test.py", line 12, in <module>
    true_params, _ = hmm.initialize(key1)
                     ^^^^^^^^^^^^^^^^^^^^
  File "/home/gdalle/Work/GitHub/Python/dynamax-test/.venv/lib/python3.12/site-packages/dynamax/hidden_markov_model/models/gaussian_hmm.py", line 649, in initialize
    params["initial"], props["initial"] = self.initial_component.initialize(key1, method=method, initial_probs=initial_probs)
                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gdalle/Work/GitHub/Python/dynamax-test/.venv/lib/python3.12/site-packages/dynamax/hidden_markov_model/models/initial.py", line 45, in initialize
    initial_probs = tfd.Dirichlet(self.initial_probs_concentration).sample(seed=this_key)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gdalle/Work/GitHub/Python/dynamax-test/.venv/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1205, in sample
    return self._call_sample_n(sample_shape, seed, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gdalle/Work/GitHub/Python/dynamax-test/.venv/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 1182, in _call_sample_n
    samples = self._sample_n(
              ^^^^^^^^^^^^^^^
  File "/home/gdalle/Work/GitHub/Python/dynamax-test/.venv/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/distributions/dirichlet.py", line 233, in _sample_n
    log_gamma_sample = gamma_lib.random_gamma(
                       ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gdalle/Work/GitHub/Python/dynamax-test/.venv/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/distributions/gamma.py", line 725, in random_gamma
    return random_gamma_with_runtime(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gdalle/Work/GitHub/Python/dynamax-test/.venv/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/distributions/gamma.py", line 718, in random_gamma_with_runtime
    seed = samplers.sanitize_seed(seed, salt='random_gamma')
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/gdalle/Work/GitHub/Python/dynamax-test/.venv/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/internal/samplers.py", line 144, in sanitize_seed
    seed = fold_in(seed, salt)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/gdalle/Work/GitHub/Python/dynamax-test/.venv/lib/python3.12/site-packages/tensorflow_probability/substrates/jax/internal/samplers.py", line 186, in fold_in
    seed, jnp.asarray(salt & np.uint32(2**32 - 1), dtype=SEED_DTYPE))
                      ~~~~~^~~~~~~~~~~~~~~~~~~~~~
OverflowError: Python int too large to convert to C long

I think it is because you haven't yet released the fixes and tagged v0.1.5.

@slinderman
Collaborator

Hi @gdalle, I just pushed 0.1.5 to PyPI and tagged it as the latest release. It includes the numpy < 2.0 requirement and the test fixes. Sorry for missing this earlier!
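
To confirm that an environment actually picked up this release, something like the following could be run. It is only a minimal sketch: the helper names (`installed_version`, `release_tuple`, `environment_ok`) are hypothetical, the distribution names `dynamax` and `numpy` are assumed to match what is installed, and the version parsing is deliberately simplistic (it reads only the leading numeric components).

```python
from importlib import metadata
from typing import Optional


def installed_version(package: str) -> Optional[str]:
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def release_tuple(version: str) -> tuple[int, ...]:
    """Parse the leading numeric components of a version such as '0.1.5' or '2.0.0rc1'."""
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def environment_ok(dynamax_version: str, numpy_version: str) -> bool:
    """True if dynamax is at least 0.1.5 and numpy is older than 2.0."""
    return release_tuple(dynamax_version) >= (0, 1, 5) and release_tuple(numpy_version) < (2,)


if __name__ == "__main__":
    dynamax_v = installed_version("dynamax")
    numpy_v = installed_version("numpy")
    if dynamax_v is None or numpy_v is None:
        print("dynamax and/or numpy is not installed in this environment")
    else:
        print(f"dynamax {dynamax_v}, numpy {numpy_v}, ok: {environment_ok(dynamax_v, numpy_v)}")
```

The check is kept free of third-party dependencies so that it can run even in a broken environment; a real project would more likely rely on the `packaging` library for version comparison.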

@gdalle
Author

gdalle commented Nov 8, 2024

OK, I finally managed to run the example code without issue on 0.1.5!

Here are a few remarks (keeping in mind that I don't usually code in Python):

  • What is the percentage of source code covered by your test suite? It should be easy to figure out with tools such as https://github.com/nedbat/coveragepy, and it could help you detect untested code.
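
For instance, here is a minimal sketch of how the coverage could be measured. It assumes the pytest-cov plugin is installed (it builds on coverage.py), that the tests live inside the `dynamax` package directory, and that the command is run from the repository root. The helper names `coverage_command` and `run_coverage` are hypothetical.

```python
import subprocess
import sys


def coverage_command(package: str, report: str = "term-missing") -> list[str]:
    """Build a pytest invocation that measures line coverage of `package`."""
    return [
        sys.executable, "-m", "pytest",
        f"--cov={package}",          # measure coverage for this package
        f"--cov-report={report}",    # "term-missing" also lists uncovered lines
        package,                     # directory in which to collect tests
    ]


def run_coverage(package: str) -> int:
    """Run the test suite under coverage and return pytest's exit code."""
    return subprocess.run(coverage_command(package), check=False).returncode
```

Calling `run_coverage("dynamax")` would then print a per-file table with the percentage of statements covered, which makes untested modules easy to spot.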
