From 032871204ab088ce748e01951d388312af673e92 Mon Sep 17 00:00:00 2001 From: gileshd Date: Thu, 12 Sep 2024 16:04:39 +0100 Subject: [PATCH] Fix LinearGaussianSSM.sample type hint --- dynamax/linear_gaussian_ssm/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/dynamax/linear_gaussian_ssm/models.py b/dynamax/linear_gaussian_ssm/models.py index 47ce6cce..4fe22d24 100644 --- a/dynamax/linear_gaussian_ssm/models.py +++ b/dynamax/linear_gaussian_ssm/models.py @@ -192,21 +192,21 @@ def emission_distribution( self, params: ParamsLGSSM, state: Float[Array, " state_dim"], - inputs: Optional[Float[Array, "ntime input_dim"]]=None + inputs: Optional[Float[Array, "ntime input_dim"]] = None ) -> tfd.Distribution: inputs = inputs if inputs is not None else jnp.zeros(self.input_dim) mean = params.emissions.weights @ state + params.emissions.input_weights @ inputs if self.has_emissions_bias: mean += params.emissions.bias return MVN(mean, params.emissions.cov) - + def sample( self, params: ParamsLGSSM, key: PRNGKeyT, num_timesteps: int, - inputs: Optional[Float[Array, "ntime input_dim"]] = None - ) -> PosteriorGSSMFiltered: + inputs: Optional[Float[Array, "num_timesteps input_dim"]] = None, + ) -> Tuple[Float[Array, "num_timesteps state_dim"], Float[Array, "num_timesteps emission_dim"]]: return lgssm_joint_sample(params, key, num_timesteps, inputs) def marginal_log_prob(