Skip to content

Commit

Permalink
Merge branch 'fabiofelix:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
fabiofelix authored Nov 15, 2024
2 parents a377ffa + 87889c9 commit 7ba0bef
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 28 deletions.
5 changes: 4 additions & 1 deletion step_recog/full/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,10 @@ def build_model(cfg_file=None, fps=10, skill=None, variant=0, checkpoint_path=No
cfg_name = cfg_file.split(os.sep).rsplit('.',1)[0]
cfg.MODEL.OMNIGRU_CHECKPOINT_URL = f'/home/user/models/{cfg_name}.pt'

return MODEL_CLASS(cfg, fps).to("cuda")
model = MODEL_CLASS(cfg, fps).to("cuda")
model.eval()
return model


@functools.lru_cache(1)
def get_omnivore(cfg_fname):
Expand Down
62 changes: 43 additions & 19 deletions step_recog/full/statemachine.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
# Correcting indentation for docstrings and rerunning the module code with tests
import numpy as np
import random
from collections import deque
from enum import IntEnum

STATE_UNOBSERVED = 0
STATE_CURRENT = 1
STATE_DONE = 2
class State(IntEnum):
UNOBSERVED = 0
CURRENT = 1
DONE = 2

class SmoothType(IntEnum):
MEAN = 0
MEDIAN = 1
EXP = 2

class ProcedureStateMachine:
"""
Expand All @@ -14,18 +22,45 @@ class ProcedureStateMachine:
Attributes:
current_state (numpy.ndarray): The current state of each step in the procedure.
"""
def __init__(self, num_steps):
def __init__(self, num_steps, maxlen = 3):
"""
Initializes the ProcedureStateMachine with a given number of steps.
Args:
num_steps (int): The number of steps in the procedure (excluding the 'no step').
maxlen (int): Exponential moving average smoothes the predictions with a 'maxlen' window
It avoids noisy predictions changing the states.
Avoid using large 'maxlen' (e.g. >= 3)
"""
self.num_steps = num_steps
self.window_probs = deque(maxlen= 1 if maxlen is None else maxlen)
self.ema_alpha = 0.5
self.reset()

def reset(self):
self.current_state = np.zeros(self.num_steps, dtype=int) ## STATE_UNOBSERVED
self.window_probs.clear()

def smooth_probs(self, probabilities, type = SmoothType.MEDIAN):
self.window_probs.append(probabilities)
window_probs_aux = np.array(self.window_probs)
move_avg = window_probs_aux[0] if window_probs_aux.shape[0] == 1 else np.median(window_probs_aux, axis = 0)

if type == SmoothType.MEAN:
move_avg = window_probs_aux[0] if window_probs_aux.shape[0] == 1 else window_probs_aux.mean(axis = 0)
elif type == SmoothType.EXP:
move_avg = np.zeros(window_probs_aux.shape[1], dtype = window_probs_aux.dtype)

for step, probs in enumerate(window_probs_aux.T):
move_avg[step] = probs[0]

for pb in probs[1:]:
move_avg[step] = self.ema_alpha * pb + (1 - self.ema_alpha) * move_avg[step]

max_prob = np.max(move_avg)
max_indices = np.where(move_avg == max_prob)[0]

return max_indices

def process_timestep(self, probabilities):
"""
Expand All @@ -34,24 +69,13 @@ def process_timestep(self, probabilities):
Args:
probabilities (numpy.ndarray): Probabilities for each step including 'no step' at the last index.
"""
step_probabilities = probabilities[:-1]
max_prob = np.max(step_probabilities)
max_prob_no_step = probabilities[-1]

max_indices = np.where(step_probabilities == max_prob)[0]

## print("max_prob", max_prob, "max_prob_no_step", max_prob_no_step, "state", self.current_state)

if max_prob_no_step >= max_prob:
if np.sum(self.current_state == STATE_CURRENT) == 1 and np.sum(self.current_state == STATE_DONE) == len(self.current_state) - 1:
self.current_state[:] = STATE_DONE
return
max_indices = self.smooth_probs(probabilities)

chosen_index = random.choice(max_indices) if len(max_indices) > 1 else max_indices[0]

if self.current_state[chosen_index] != STATE_CURRENT:
self.current_state[self.current_state == STATE_CURRENT] = STATE_DONE
self.current_state[chosen_index] = STATE_CURRENT
if self.current_state[chosen_index] != State.CURRENT:
self.current_state[self.current_state == State.CURRENT] = State.DONE
self.current_state[chosen_index] = State.CURRENT

# Define the tests within the same environment
def run_tests():
Expand Down
18 changes: 10 additions & 8 deletions step_recog/full/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def main(video_path, output_path='output.mp4', cfg_file=""):
video_info = sv.VideoInfo.from_video_path(video_path)

# define model
model = build_model(cfg_file, video_info.fps)
psm = ProcedureStateMachine(model.cfg.MODEL.OUTPUT_DIM)
model = build_model(cfg_file, fps=video_info.fps)
psm = ProcedureStateMachine(model.cfg.MODEL.OUTPUT_DIM + 1)

step_process = video_info.fps #1 second by default
prob_step = np.zeros(model.cfg.MODEL.OUTPUT_DIM + 1)
Expand All @@ -34,21 +34,23 @@ def main(video_path, output_path='output.mp4', cfg_file=""):

with sv.VideoSink(output_path, video_info=video_info) as sink:
# iterate over video frames
for idx, frame in tqdm.tqdm(enumerate(sv.get_video_frames_generator(video_path))):
pbar = tqdm.tqdm(enumerate(sv.get_video_frames_generator(video_path)))
for idx, frame in pbar:
frame_aux = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
frame_aux = model.prepare(frame_aux)
model.queue_frame(frame_aux)

if idx % step_process == 0:
# take in a queue frame and make the next prediction
prob_step = model(frame_aux, queue_frame = False).cpu().squeeze().numpy()
psm.process_timestep(prob_step)
step_idx = np.argmax(prob_step)
step_desc = "No step" if step_idx >= len(model.STEPS) else model.STEPS[step_idx]
psm.process_timestep(prob_step)
step_desc = "No step" if step_idx >= len(model.STEPS) else model.STEPS[step_idx]

pbar.set_description(" ".join(f"{x:.0%}" for x in prob_step) + " | " + " ".join(f'{x}' for x in psm.current_state))

# draw the prediction (could be your bar chart) on the frame
plot_graph(frame, prob_step, step_desc, psm.current_state)
plot_graph(frame, prob_step, step_desc, psm.current_state[:-1])
sink.write_frame(frame)

##TODO: Review the offsets
Expand Down Expand Up @@ -100,4 +102,4 @@ def plot_graph(frame, prob_step, step_desc, current_state, tl=(10, 25), scale=1.

if __name__ == '__main__':
import fire
fire.Fire(main)
fire.Fire(main)

0 comments on commit 7ba0bef

Please sign in to comment.