diff --git a/dynamax/hidden_markov_model/models/linreg_hmm.py b/dynamax/hidden_markov_model/models/linreg_hmm.py index c41fd78c..c0a03d81 100644 --- a/dynamax/hidden_markov_model/models/linreg_hmm.py +++ b/dynamax/hidden_markov_model/models/linreg_hmm.py @@ -94,7 +94,7 @@ def distribution( self, params: ParamsLinearRegressionHMMEmissions, state: Union[int, Int[Array, ""]], - inputs: Array + inputs: Float[Array, " input_dim"] ): prediction = params.weights[state] @ inputs prediction += params.biases[state]