
Commit

Updated main script for maintaining batch data through observations
ncguilbeault committed Jan 7, 2025
1 parent cbf4640 commit 6c981fb
Showing 2 changed files with 7 additions and 11 deletions.
src/Bonsai.ML.HiddenMarkovModels/InferState.bonsai (2 changes: 1 addition & 1 deletion)
@@ -52,7 +52,7 @@
</Expression>
<Expression xsi:type="Combinator">
<Combinator xsi:type="py:Exec">
- <py:Script>hmm.most_likely_states([59.7382107943162,3.99285183724331])</py:Script>
+ <py:Script>hmm.infer_state([59.7382107943162,3.99285183724331])</py:Script>
</Combinator>
</Expression>
<Expression xsi:type="WorkflowOutput" />
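For context, the Exec node above passes a single hard-coded two-dimensional observation into the Python scope that holds the model, and with the updated main.py (next file) each infer_state call returns one filtered state index while recording the observation internally. The stub below is only an illustration of that per-call contract; StubModel and everything in it apart from the infer_state name are invented for the sketch and are not the wrapper defined in main.py.

# Hypothetical stub of the call contract exercised by the Exec node above.
# Only the infer_state name comes from this commit; the rest is invented.
class StubModel:
    def __init__(self):
        self.batch_observations = []
        self.predicted_states = []

    def infer_state(self, observation):
        # The real method advances a forward filter by one step; the stub
        # just records the observation and returns a fixed state index.
        self.batch_observations.append(observation)
        self.predicted_states.append(0)
        return 0

hmm = StubModel()
state = hmm.infer_state([59.7382107943162, 3.99285183724331])
print(state, len(hmm.batch_observations))  # 0 1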
src/Bonsai.ML.HiddenMarkovModels/main.py (16 changes: 6 additions & 10 deletions)
@@ -122,15 +122,15 @@ def update_params(self, initial_state_distribution, transitions_params, observat
else:
self.observations_params = (hmm_params[2],)

- def get_predicted_states(self):
-     self.predicted_states = np.array([self.infer_state(obs) for obs in self.batch_observations]).astype(int)

def infer_state(self, observation: list[float]):

- self.log_alpha = self.compute_log_alpha(
-     np.expand_dims(np.array(observation), 0), self.log_alpha)
+ observation = np.expand_dims(np.array(observation), 0)
+ self.log_alpha = self.compute_log_alpha(observation, self.log_alpha)
self.state_probabilities = np.exp(self.log_alpha).astype(np.double)
- return self.state_probabilities.argmax()
+ prediction = self.state_probabilities.argmax()
+ self.predicted_states = np.append(self.predicted_states, prediction)
+ self.batch_observations = np.vstack([self.batch_observations, observation])
+ return prediction

def compute_log_alpha(self, obs, log_alpha=None):

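The reworked infer_state above performs one forward-filtering step per observation and then appends both the predicted state and the observation to the running history, which is how batch data is now maintained through observations. As a rough, self-contained illustration, here is a sketch of the standard log-space forward recursion that a method like compute_log_alpha typically implements (its body is not shown in this diff), together with the same per-observation bookkeeping; the two-state Gaussian parameters and the normalization step are invented for the example.

# Illustrative sketch only, not code from this commit. Parameters are invented.
import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal

log_pi = np.log([0.5, 0.5])                     # initial state distribution
log_A = np.log([[0.95, 0.05], [0.10, 0.90]])    # transition matrix
means = [np.zeros(2), np.array([60.0, 4.0])]
covs = [np.eye(2), np.eye(2)]

def log_likelihoods(obs):
    # log p(obs | z = k) for each hidden state k
    return np.array([multivariate_normal.logpdf(obs, means[k], covs[k])
                     for k in range(len(means))])

def compute_log_alpha(obs, log_alpha=None):
    # One step of the forward recursion in log space:
    # log alpha_t(j) = log p(obs_t | j) + logsumexp_i(log alpha_{t-1}(i) + log A[i, j])
    ll = log_likelihoods(obs)
    if log_alpha is None:
        log_alpha = log_pi + ll
    else:
        log_alpha = ll + logsumexp(log_alpha[:, None] + log_A, axis=0)
    return log_alpha - logsumexp(log_alpha)     # normalize for numerical stability

# Per-observation bookkeeping mirroring the shape of the new infer_state.
log_alpha = None
predicted_states = np.empty(0, dtype=int)
batch_observations = np.empty((0, 2))
for observation in ([59.7382107943162, 3.99285183724331], [0.1, -0.2]):
    observation = np.expand_dims(np.array(observation), 0)
    log_alpha = compute_log_alpha(observation[0], log_alpha)
    state_probabilities = np.exp(log_alpha).astype(np.double)
    prediction = state_probabilities.argmax()
    predicted_states = np.append(predicted_states, prediction)
    batch_observations = np.vstack([batch_observations, observation])
print(predicted_states, batch_observations.shape)  # [1 0] (2, 2)

The only point of the sketch is that a single call advances the filter by one observation, so predictions no longer need to be recomputed over the whole stored batch.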
@@ -174,8 +174,6 @@ def fit_async(self,
self.batch = np.vstack(
[self.batch[1:], np.expand_dims(np.array(observation), 0)])

- self.batch_observations = self.batch

if not self.is_running and self.loop is None and self.thread is None:

if self.curr_batch_size >= batch_size:
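The vstack call above keeps self.batch as a fixed-size sliding window, dropping the oldest row as each new observation arrives. Because infer_state now appends to self.batch_observations directly, the separate self.batch_observations = self.batch assignment is dropped in this hunk. A tiny self-contained sketch of the window behaviour follows (batch_size and the values are invented):

# Sliding-window sketch only; values are invented for illustration.
import numpy as np

batch_size = 3
batch = np.zeros((batch_size, 2))
for observation in ([1.0, 1.0], [2.0, 2.0], [3.0, 3.0], [4.0, 4.0]):
    # Drop the oldest row, append the newest observation.
    batch = np.vstack([batch[1:], np.expand_dims(np.array(observation), 0)])
print(batch)  # rows for the last three observations: 2.0, 3.0, 4.0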
@@ -224,8 +222,6 @@ def on_completion(future):
if self.flush_data_between_batches:
self.batch = None

- self.get_predicted_states()

self.is_running = True

if self.loop is None or self.loop.is_closed():
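The removed self.get_predicted_states() call rebuilt the entire prediction array from self.batch_observations in one pass; with infer_state appending each prediction as it is computed, that bulk pass is no longer performed here. For readers unfamiliar with the surrounding loop/thread machinery, the sketch below shows one common way to run a fitting coroutine on a background asyncio event loop with a completion callback. It only echoes names visible in this hunk (fit_async, on_completion, loop, thread) and is not the actual wiring in main.py, most of which lies outside this diff.

# Generic background-loop sketch; not the implementation in main.py.
import asyncio
import threading

async def fit_async():
    await asyncio.sleep(0.1)       # stand-in for the real model fitting
    return "fitted parameters"

def on_completion(future):
    # The real on_completion in main.py does more bookkeeping than this sketch shows.
    print("fit finished:", future.result())

loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()

future = asyncio.run_coroutine_threadsafe(fit_async(), loop)
future.add_done_callback(on_completion)
future.result()                    # block only so the demo exits cleanly
loop.call_soon_threadsafe(loop.stop)
thread.join()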
