Add further type annotations to hmm initial base class
Further, enforce that either key or initial_probs is specified in the
initialize method.
gileshd committed Oct 16, 2024
1 parent bff1249 commit 9b6261d
44 changes: 31 additions & 13 deletions dynamax/hidden_markov_model/models/initial.py
@@ -1,11 +1,13 @@
-from dynamax.hidden_markov_model.models.abstractions import HMMInitialState
-from dynamax.parameters import ParameterProperties
+from typing import Any, cast, NamedTuple, Optional, Tuple, Union
 import jax.numpy as jnp
 import jax.random as jr
 from jaxtyping import Float, Array
 import tensorflow_probability.substrates.jax.distributions as tfd
 import tensorflow_probability.substrates.jax.bijectors as tfb
-from typing import NamedTuple, Union
+from dynamax.hidden_markov_model.inference import HMMPosterior
+from dynamax.hidden_markov_model.models.abstractions import HMMInitialState
+from dynamax.parameters import ParameterProperties
+from dynamax.types import Scalar
 
 
 class ParamsStandardHMMInitialState(NamedTuple):
@@ -17,18 +19,23 @@ class StandardHMMInitialState(HMMInitialState):
     """
     def __init__(self,
                  num_states: int,
-                 initial_probs_concentration: Union[float, Float[Array, " num_states"]]=1.1):
+                 initial_probs_concentration: Union[Scalar, Float[Array, " num_states"]]=1.1):
         """
         Args:
             initial_probabilities[k]: prob(hidden(1)=k)
         """
         self.num_states = num_states
         self.initial_probs_concentration = initial_probs_concentration * jnp.ones(num_states)
 
-    def distribution(self, params, inputs=None):
+    def distribution(self, params: ParamsStandardHMMInitialState, inputs=None) -> tfd.Distribution:
         return tfd.Categorical(probs=params.probs)
 
-    def initialize(self, key=None, method="prior", initial_probs=None):
+    def initialize(
+        self,
+        key: Optional[Array]=None,
+        method="prior",
+        initial_probs: Optional[Float[Array, " num_states"]]=None
+    ) -> Tuple[ParamsStandardHMMInitialState, ParamsStandardHMMInitialState]:
         """Initialize the model parameters and their corresponding properties.
 
         Args:
@@ -41,27 +48,38 @@ def initialize(self, key=None, method="prior", initial_probs=None):
         """
         # Initialize the initial probabilities
         if initial_probs is None:
-            this_key, key = jr.split(key)
-            initial_probs = tfd.Dirichlet(self.initial_probs_concentration).sample(seed=this_key)
+            if key is None:
+                raise ValueError("key must be provided if initial_probs is not provided.")
+            else:
+                this_key, key = jr.split(key)
+                initial_probs = tfd.Dirichlet(self.initial_probs_concentration).sample(seed=this_key)
 
         # Package the results into dictionaries
         params = ParamsStandardHMMInitialState(probs=initial_probs)
         props = ParamsStandardHMMInitialState(probs=ParameterProperties(constrainer=tfb.SoftmaxCentered()))
         return params, props
 
-    def log_prior(self, params):
+    def log_prior(self, params: ParamsStandardHMMInitialState) -> Scalar:
         return tfd.Dirichlet(self.initial_probs_concentration).log_prob(params.probs)
 
-    def _compute_initial_probs(self, params, inputs=None):
+    def _compute_initial_probs(
+        self, params: ParamsStandardHMMInitialState, inputs=None
+    ) -> Float[Array, " num_states"]:
         return params.probs
 
-    def collect_suff_stats(self, params, posterior, inputs=None):
+    def collect_suff_stats(self, params, posterior: HMMPosterior, inputs=None) -> Float[Array, " num_states"]:
         return posterior.smoothed_probs[0]
 
-    def initialize_m_step_state(self, params, props):
+    def initialize_m_step_state(self, params, props) -> None:
         return None
 
-    def m_step(self, params, props, batch_stats, m_step_state):
+    def m_step(
+        self,
+        params: ParamsStandardHMMInitialState,
+        props: ParamsStandardHMMInitialState,
+        batch_stats: Float[Array, "batch num_states"],
+        m_step_state: Any
+    ) -> Tuple[ParamsStandardHMMInitialState, Any]:
         if props.probs.trainable:
             if self.num_states == 1:
                 probs = jnp.array([1.0])
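With this change, initialize samples initial_probs from the Dirichlet prior only when a PRNG key is supplied, and raises a ValueError up front when neither key nor initial_probs is given, rather than failing inside jr.split. A minimal sketch of the new contract; the three-state model and the probability values here are made up for illustration:

import jax.numpy as jnp
import jax.random as jr
from dynamax.hidden_markov_model.models.initial import StandardHMMInitialState

init_state = StandardHMMInitialState(num_states=3)

# Sample initial_probs from the Dirichlet prior: a key is required.
params, props = init_state.initialize(key=jr.PRNGKey(0))

# Supply initial_probs explicitly: no key is needed.
params, props = init_state.initialize(initial_probs=jnp.array([0.2, 0.3, 0.5]))

# init_state.initialize() with neither argument now raises ValueError.

Since initial_probs_concentration is now typed Union[Scalar, Float[Array, " num_states"]], the constructor accepts either a scalar (broadcast across states via jnp.ones(num_states)) or a per-state concentration array.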
