From cbf464049f3bb0bc7bf08f065bdbe00d1218ea7d Mon Sep 17 00:00:00 2001 From: ncguilbeault Date: Thu, 21 Nov 2024 15:43:59 +0000 Subject: [PATCH] Update name to predicted states --- .../Observations/GaussianObservationsStatistics.cs | 10 +++++----- src/Bonsai.ML.HiddenMarkovModels/main.py | 7 +++++-- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs index 86b84959..a1deec68 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs +++ b/src/Bonsai.ML.HiddenMarkovModels/Observations/GaussianObservationsStatistics.cs @@ -45,11 +45,11 @@ public class GaussianObservationsStatistics public double[,] BatchObservations { get; set; } /// - /// The sequence of inferred most probable states. + /// The predicted state for each observation in the batch of observations. /// - [Description("The sequence of inferred most probable states.")] + [Description("The predicted state for each observation in the batch of observations.")] [XmlIgnore] - public long[] InferredMostProbableStates { get; set; } + public long[] PredictedStates { get; set; } /// /// Transforms an observable sequence of into an observable sequence @@ -64,7 +64,7 @@ public IObservable Process(IObservable var covarianceMatricesPyObj = (double[,,])observationsPyObj.GetArrayAttr("Sigmas"); var stdDevsPyObj = DiagonalSqrt(covarianceMatricesPyObj); var batchObservationsPyObj = (double[,])pyObject.GetArrayAttr("batch_observations"); - var inferredMostProbableStatesPyObj = (long[])pyObject.GetArrayAttr("inferred_most_probable_states"); + var predictedStatesPyObj = (long[])pyObject.GetArrayAttr("predicted_states"); return new GaussianObservationsStatistics { @@ -72,7 +72,7 @@ public IObservable Process(IObservable StdDevs = stdDevsPyObj, CovarianceMatrices = covarianceMatricesPyObj, BatchObservations = batchObservationsPyObj, - InferredMostProbableStates = inferredMostProbableStatesPyObj + PredictedStates = PredictedStates }; }); } diff --git a/src/Bonsai.ML.HiddenMarkovModels/main.py b/src/Bonsai.ML.HiddenMarkovModels/main.py index aad1386e..3f83b932 100644 --- a/src/Bonsai.ML.HiddenMarkovModels/main.py +++ b/src/Bonsai.ML.HiddenMarkovModels/main.py @@ -85,7 +85,7 @@ def get_nonlinearity_type(func): self.thread = None self.curr_batch_size = 0 self.flush_data_between_batches = True - self.inferred_most_probable_states = np.array([], dtype=int) + self.predicted_states = np.array([], dtype=int) def update_params(self, initial_state_distribution, transitions_params, observations_params): hmm_params = self.params @@ -122,6 +122,9 @@ def update_params(self, initial_state_distribution, transitions_params, observat else: self.observations_params = (hmm_params[2],) + def get_predicted_states(self): + self.predicted_states = np.array([self.infer_state(obs) for obs in self.batch_observations]).astype(int) + def infer_state(self, observation: list[float]): self.log_alpha = self.compute_log_alpha( @@ -221,7 +224,7 @@ def on_completion(future): if self.flush_data_between_batches: self.batch = None - self.inferred_most_probable_states = np.array([self.infer_state(obs) for obs in self.batch_observations]).astype(int) + self.get_predicted_states() self.is_running = True