Skip to content

Commit

Permalink
Update name to predicted states
Browse files Browse the repository at this point in the history
  • Loading branch information
ncguilbeault committed Jan 7, 2025
1 parent 2447428 commit cbf4640
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ public class GaussianObservationsStatistics
public double[,] BatchObservations { get; set; }

/// <summary>
/// The sequence of inferred most probable states.
/// The predicted state for each observation in the batch of observations.
/// </summary>
[Description("The sequence of inferred most probable states.")]
[Description("The predicted state for each observation in the batch of observations.")]
[XmlIgnore]
public long[] InferredMostProbableStates { get; set; }
public long[] PredictedStates { get; set; }

/// <summary>
/// Transforms an observable sequence of <see cref="PyObject"/> into an observable sequence
Expand All @@ -64,15 +64,15 @@ public IObservable<GaussianObservationsStatistics> Process(IObservable<PyObject>
var covarianceMatricesPyObj = (double[,,])observationsPyObj.GetArrayAttr("Sigmas");
var stdDevsPyObj = DiagonalSqrt(covarianceMatricesPyObj);
var batchObservationsPyObj = (double[,])pyObject.GetArrayAttr("batch_observations");
var inferredMostProbableStatesPyObj = (long[])pyObject.GetArrayAttr("inferred_most_probable_states");
var predictedStatesPyObj = (long[])pyObject.GetArrayAttr("predicted_states");

return new GaussianObservationsStatistics
{
Means = meansPyObj,
StdDevs = stdDevsPyObj,
CovarianceMatrices = covarianceMatricesPyObj,
BatchObservations = batchObservationsPyObj,
InferredMostProbableStates = inferredMostProbableStatesPyObj
PredictedStates = PredictedStates
};
});
}
Expand Down
7 changes: 5 additions & 2 deletions src/Bonsai.ML.HiddenMarkovModels/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def get_nonlinearity_type(func):
self.thread = None
self.curr_batch_size = 0
self.flush_data_between_batches = True
self.inferred_most_probable_states = np.array([], dtype=int)
self.predicted_states = np.array([], dtype=int)

def update_params(self, initial_state_distribution, transitions_params, observations_params):
hmm_params = self.params
Expand Down Expand Up @@ -122,6 +122,9 @@ def update_params(self, initial_state_distribution, transitions_params, observat
else:
self.observations_params = (hmm_params[2],)

def get_predicted_states(self):
self.predicted_states = np.array([self.infer_state(obs) for obs in self.batch_observations]).astype(int)

def infer_state(self, observation: list[float]):

self.log_alpha = self.compute_log_alpha(
Expand Down Expand Up @@ -221,7 +224,7 @@ def on_completion(future):
if self.flush_data_between_batches:
self.batch = None

self.inferred_most_probable_states = np.array([self.infer_state(obs) for obs in self.batch_observations]).astype(int)
self.get_predicted_states()

self.is_running = True

Expand Down

0 comments on commit cbf4640

Please sign in to comment.