Merge pull request fabiofelix#46 from fabiofelix/updates
Fixed sound embedding extraction
fabiofelix authored Jun 5, 2024
2 parents c904d02 + 27fc544 commit 013cc68
Showing 5 changed files with 106 additions and 51 deletions.
1 change: 1 addition & 0 deletions step_recog/config/defaults.py
@@ -55,6 +55,7 @@
 # Dataset options
 # -----------------------------------------------------------------------------
 _C.DATASET = CfgNode()
+_C.DATASET.CLASS = 'Milly_multifeature_v4'
 _C.DATASET.NAME = ''
 _C.DATASET.LOCATION = ''
 _C.DATASET.AUDIO_LOCATION = ''
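
Context note: this new key is read in tools/run_step_recog.py (further down in this commit) with getattr, so the dataset class is chosen from the config instead of a hard-coded import. A minimal sketch of that lookup pattern, with a stand-in namespace in place of the real step_recog.datasets (the class body here is illustrative only):

import types

datasets = types.SimpleNamespace()  # stands in for the step_recog.datasets module

class Milly_multifeature_v4:
  def __init__(self, cfg, split='train', filter=None):
    self.split = split

datasets.Milly_multifeature_v4 = Milly_multifeature_v4

class_name = 'Milly_multifeature_v4'           # would come from cfg.DATASET.CLASS
DATASET_CLASS = getattr(datasets, class_name)  # resolve the class by its name
ts_dataset = DATASET_CLASS(None, split='test')
print(type(ts_dataset).__name__)               # Milly_multifeature_v4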
65 changes: 44 additions & 21 deletions step_recog/datasets/milly.py
@@ -1,9 +1,8 @@
 import os
 import torch
 import tqdm
-import numpy as np
+import numpy as np, numba
 import pandas as pd
-import copy
 import glob
 import ipdb
 import cv2
@@ -110,38 +109,61 @@ def collate_fn(data):

 ##TODO: It's returning the whole video
 class Milly_multifeature(torch.utils.data.Dataset):

   def __init__(self, cfg, split='train', filter=None):
-      self.cfg = cfg
-      self.data_filter = filter
+    self.cfg = cfg
+    self.data_filter = filter

-      if split == 'train':
-        self.annotations_file = cfg.DATASET.TR_ANNOTATIONS_FILE
-      elif split == 'validation':
-        self.annotations_file = cfg.DATASET.TR_ANNOTATIONS_FILE if cfg.DATASET.VL_ANNOTATIONS_FILE == '' else cfg.DATASET.VL_ANNOTATIONS_FILE
-      elif split == 'test':
-        self.annotations_file = cfg.DATASET.VL_ANNOTATIONS_FILE if cfg.DATASET.TS_ANNOTATIONS_FILE == '' else cfg.DATASET.TS_ANNOTATIONS_FILE
+    if split == 'train':
+      self.annotations_file = cfg.DATASET.TR_ANNOTATIONS_FILE
+    elif split == 'validation':
+      self.annotations_file = cfg.DATASET.TR_ANNOTATIONS_FILE if cfg.DATASET.VL_ANNOTATIONS_FILE == '' else cfg.DATASET.VL_ANNOTATIONS_FILE
+    elif split == 'test':
+      self.annotations_file = cfg.DATASET.VL_ANNOTATIONS_FILE if cfg.DATASET.TS_ANNOTATIONS_FILE == '' else cfg.DATASET.TS_ANNOTATIONS_FILE

-      self.image_augs = cfg.DATASET.INCLUDE_IMAGE_AUGMENTATIONS if split == 'train' else False
-      self.time_augs = cfg.DATASET.INCLUDE_TIME_AUGMENTATIONS if split == 'train' else False
+    self.image_augs = cfg.DATASET.INCLUDE_IMAGE_AUGMENTATIONS if split == 'train' else False
+    self.time_augs = cfg.DATASET.INCLUDE_TIME_AUGMENTATIONS if split == 'train' else False

-      self.rng = np.random.default_rng()
-      self._construct_loader(split)
+    self.rng = np.random.default_rng()
+    self._construct_loader(split)

   def _construct_loader(self, split):
     self.datapoints = {}
     self.class_histogram = []
-    pass
+    self.overlap_summary = {}

   def __len__(self):
-      return len(self.datapoints)
+    return len(self.datapoints)

 import sys
 from collections import deque

-#to work with: torch.multiprocessing.set_start_method('spawn')
+##https://stackoverflow.com/questions/44131691/how-to-clear-cache-or-force-recompilation-in-numba
+##https://numba.pydata.org/numba-doc/0.48.0/developer/caching.html#cache-clearing
+##https://numba.pydata.org/numba-doc/0.48.0/reference/envvars.html#envvar-NUMBA_CACHE_DIR
+#to save numba cache out the /home folder
+main_cache_path = os.path.join("/vast", os.path.basename(os.path.expanduser("~")))
+clip_download_root = None
 omni_path = os.path.join(os.path.expanduser("~"), ".cache/torch/hub/facebookresearch_omnivore_main")
-sys.path.append(omni_path)
+
+if os.path.isdir(main_cache_path):
+  cache_path = os.path.join(main_cache_path, "cache")
+
+  if not os.path.isdir(cache_path):
+    os.mkdir(cache_path)
+
+  numba.config.CACHE_DIR = cache_path #default: ~/.cache
+  clip_download_root = os.path.join(cache_path, "clip") #default: ~/.cache/clip
+
+  cache_path = os.path.join(cache_path, "torch", "hub")
+
+  if not os.path.isdir(cache_path):
+    os.makedirs(cache_path)
+
+  torch.hub.set_dir(cache_path) #default: ~/.cache/torch/hub
+  omni_path = os.path.join(cache_path, "facebookresearch_omnivore_main")
+
+#to work with: torch.multiprocessing.set_start_method('spawn')
+sys.path.append(omni_path)

 from ultralytics import YOLO
 #from torch.quantization import quantize_dynamic
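
The new block pins the numba, CLIP, and torch hub caches under /vast when that mount exists, instead of the ~/.cache defaults. A hedged sketch of the same relocation done through environment variables before the heavy imports run; the /vast layout is site-specific and illustrative:

import os

cache_root = os.path.join("/vast", os.path.basename(os.path.expanduser("~")), "cache")
os.makedirs(os.path.join(cache_root, "torch", "hub"), exist_ok=True)

os.environ["NUMBA_CACHE_DIR"] = cache_root                    # numba's documented cache knob
os.environ["TORCH_HOME"] = os.path.join(cache_root, "torch")  # torch.hub resolves $TORCH_HOME/hub

import torch  # imported after the variables are set so the defaults pick them up

print(torch.hub.get_dir())  # .../cache/torch/hub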
Expand Down Expand Up @@ -174,6 +196,7 @@ def slowfast_hook(module, input, output):
embedding = input[0]
batch_size, _, _, _ = embedding.shape
output = embedding.reshape(batch_size, -1)
global SOUND_FEATURES_LIST
SOUND_FEATURES_LIST.extend(output.cpu().detach().numpy())

class Milly_multifeature_v4(Milly_multifeature):
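
This is the change the commit title refers to: slowfast_hook accumulates flattened sound embeddings into the module-level SOUND_FEATURES_LIST, and the hook now declares it global explicitly. A minimal sketch of the same forward-hook pattern on a toy module (the module and sizes are illustrative, not the real Slowfast model):

import torch
import torch.nn as nn

SOUND_FEATURES_LIST = []  # module-level buffer the hook fills

def capture_hook(module, input, output):
  global SOUND_FEATURES_LIST
  embedding = input[0]                      # features entering the hooked layer
  batch_size = embedding.shape[0]
  flat = embedding.reshape(batch_size, -1)  # one flat vector per example
  SOUND_FEATURES_LIST.extend(flat.cpu().detach().numpy())

toy = nn.Sequential(nn.Flatten(), nn.Linear(16, 4), nn.Linear(4, 2))
toy[-1].register_forward_hook(capture_hook)  # capture the input of the last layer

with torch.no_grad():
  toy(torch.randn(3, 4, 2, 2))

print(len(SOUND_FEATURES_LIST), SOUND_FEATURES_LIST[0].shape)  # 3 (4,)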
@@ -193,7 +216,7 @@ def __init__(self, cfg, split='train', filter=None):
     self.yolo.eval = yolo_eval #to work with: torch.multiprocessing.set_start_method('spawn')
     # self.yolo = quantize_dynamic(self.yolo, {torch.nn.Linear, torch.nn.Conv2d}, dtype=torch.qint8)

-    self.clip_patches = ClipPatches()
+    self.clip_patches = ClipPatches(download_root=clip_download_root)
     self.clip_patches.eval()

     if self.cfg.MODEL.USE_ACTION:
@@ -699,7 +722,7 @@ def __getitem__(self, index):
     window_step_label = torch.from_numpy(np.array(window_step_label))
     window_position_label = torch.from_numpy(np.array(window_position_label))
     window_stop_frame = torch.from_numpy(np.array(window_stop_frame))
-    video_id = np.array([window["video_id"]])
+    video_id = np.array([video["video_id"]])


     return video_act, video_obj, video_frame, video_sound, window_step_label, window_position_label, window_stop_frame, video_id
4 changes: 2 additions & 2 deletions step_recog/full/clip_patches.py
@@ -11,9 +11,9 @@ class ClipPatches(nn.Module):
   TODO: this could be implemented with a "whole frame bounding box".
   '''
-  def __init__(self):
+  def __init__(self, download_root = None):
     super().__init__()
-    self.model, self.transform = clip.load("ViT-B/16", jit=False)
+    self.model, self.transform = clip.load("ViT-B/16", jit=False, download_root=download_root)
     self._device = nn.Parameter(torch.empty(0))

   def stack_patches(self, patches):
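
Usage note: download_root is forwarded to OpenAI CLIP's clip.load, which caches the ViT-B/16 weights under ~/.cache/clip when the argument is None. A short sketch with an illustrative path:

import clip
import torch

model, preprocess = clip.load("ViT-B/16", device="cpu", jit=False,
                              download_root="/vast/me/cache/clip")
model.eval()

with torch.no_grad():
  features = model.encode_image(torch.randn(1, 3, 224, 224))

print(features.shape)  # torch.Size([1, 512])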
56 changes: 43 additions & 13 deletions step_recog/iterators.py
@@ -7,7 +7,7 @@
 import pdb, ipdb
 import json
 import scipy
-from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score, accuracy_score
+from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score, accuracy_score, precision_score, recall_score
 from matplotlib import pyplot as plt
 import seaborn as sb, pandas as pd
 import warnings
@@ -20,15 +20,18 @@ def build_model(cfg):

   return model, device

+def get_class_weight(class_histogram):
+  class_weight = np.array(class_histogram) / np.sum(class_histogram)
+  class_weight = np.divide(1.0, class_weight, where = class_weight != 0) #avoid zero-division
+
+  return class_weight / np.sum(class_weight) ## norm in [0, 1]
+
 def build_losses(loader, cfg, device):
   class_weight = None
   class_weight_tensor = None

   if cfg.TRAIN.USE_CLASS_WEIGHT:
-    class_weight = np.array(loader.dataset.class_histogram) / np.sum(loader.dataset.class_histogram)
-    class_weight = np.divide(1.0, class_weight, where = class_weight != 0) #avoid zero-division
-    class_weight = class_weight / np.sum(class_weight) ## norm in [0, 1]
-
+    class_weight = get_class_weight(loader.dataset.class_histogram)
     print("|- Class weights", class_weight)

     class_weight_tensor = torch.FloatTensor(class_weight).to(device)
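
A small worked example of the refactored get_class_weight: inverse class frequency, renormalized to sum to 1, so rare classes get the largest weights (the histogram values are made up):

import numpy as np

def get_class_weight(class_histogram):
  class_weight = np.array(class_histogram) / np.sum(class_histogram)
  class_weight = np.divide(1.0, class_weight, where = class_weight != 0) #avoid zero-division
  return class_weight / np.sum(class_weight) ## norm in [0, 1]

print(get_class_weight([80, 15, 5]))  # approx [0.045 0.239 0.716]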
@@ -275,6 +278,15 @@ def train(train_loader, val_loader, cfg):
       best_val_acc = val_acc
       torch.save(model.state_dict(), best_model_path)

+      ##Saving validation metrics
+      classes = [ i for i in range(model.number_classes)]
+      classes_desc = [ "Step " + str(i + 1) for i in range(model.number_classes)]
+      classes_desc[-1] = "No step"
+      original_output = cfg.OUTPUT.LOCATION
+      cfg.OUTPUT.LOCATION = os.path.join(original_output, "validation" )
+      save_evaluation(val_targets, val_outputs, classes, cfg, label_order = classes_desc, class_weight = val_class_weight)
+      cfg.OUTPUT.LOCATION = original_output
+
     save_current_state(cfg, model, history, epoch)

   plot_history(history, cfg)
@@ -293,7 +305,7 @@ def evaluate(model, data_loader, cfg):
   targets = []
   _, _, class_weight = build_losses(data_loader, cfg, device)

-  for action, obj, frame, audio, label, _, _, frame_idx, videos in data_loader:
+  for action, obj, frame, audio, label, _, mask, frame_idx, videos in data_loader:
     h = model.init_hidden(len(action))

     out, _ = model(action.to(device).float(), h, audio.to(device).float(), obj.to(device).float(), frame.to(device).float(), return_last_step = False)
@@ -302,21 +314,21 @@
     frame_idx = frame_idx.cpu().numpy()
     torch.cuda.empty_cache()

-    for video_id, video_frames, frame_target, frame_pred in zip(videos, frame_idx, label, out):
+    for video_id, video_frames, frame_target, frame_pred, video_masks in zip(videos, frame_idx, label, out, mask):
       aux_frame = []
       aux_targets = []
       aux_outputs = []

-      for frame, target, pred in zip(video_frames, frame_target, frame_pred):
-        if frame > 0: #it's equal to test the mask value, like in train_step
+      for frame, target, pred, mask in zip(video_frames, frame_target, frame_pred, video_masks):
+        if mask > 0: #it's equal to test the mask value, like in train_step
           aux_frame.append(frame)
           aux_targets.append(target)
           aux_outputs.append(pred)

           targets.append(target)
           outputs.append(pred)

-          save_video_evaluation(video_id, aux_frame, aux_targets, aux_outputs, cfg)
+      save_video_evaluation(video_id, aux_frame, aux_targets, aux_outputs, cfg)

   targets = np.array(targets)
   outputs = np.array(outputs)
@@ -605,6 +617,16 @@ def save_video_evaluation(video_id, window_last_frame, expected, probs, cfg):
   predicted[predicted == last_predicted] = -1

   classes_desc = [ "No step" if i == 0 else "Step " + str(i) for i in range(last_expected + 1)]
+  accuracy = accuracy_score(expected, predicted)
+  acc_desc = "acc"
+
+  if cfg.TRAIN.USE_CLASS_WEIGHT:
+    acc_desc = "weighted acc"
+    class_weight = get_class_weight([ np.sum(expected == c) for c in np.unique(expected) ])
+    accuracy = weighted_accuracy(expected, predicted, class_weight=class_weight)
+
+  precision = precision_score(expected, predicted, average=None)
+  recall = recall_score(expected, predicted, average=None)

   figure = plt.figure(figsize = (1024 / 100, 768 / 100), dpi = 100)

@@ -614,9 +636,17 @@
   plt.yticks( [ i - 1 for i in range(last_expected + 1) ], classes_desc)

   plt.step(window_last_frame, predicted, c="orange")
-  plt.yticks( [ i - 1 for i in range(last_predicted + 1) ], classes_desc)
-
-  plt.legend(["target", "predicted"])
+  plt.yticks( [ i - 1 for i in range(last_predicted + 1) ], classes_desc)
+
+  plt.plot(1, np.min([expected, predicted]), 'white')
+  plt.plot(1, np.min([expected, predicted]), 'white')
+  plt.plot(1, np.min([expected, predicted]), 'white')
+
+  plt.legend(["target", "predicted",
+              "{} {:.2f}".format(acc_desc, accuracy ),
+              "precision {:.2f}+/-{:.2f}".format( precision.mean(), precision.std() ),
+              "recall {:.2f}+/-{:.2f}".format( recall.mean(), recall.std() )
+             ])
   plt.grid(axis = "y")

   probs = np.max(probs, axis = 1)
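
On the evaluate() change above: padded window positions are now skipped with the mask produced by collate_fn rather than by testing frame_idx > 0. A made-up example of one way the two filters can disagree, assuming frame index 0 can belong to a real (unpadded) window:

import numpy as np

frame_idx = np.array([0, 12, 24, 0, 0])  # last two entries are padding
mask      = np.array([1, 1, 1, 0, 0])    # 1 = real window, 0 = padded

print(frame_idx > 0)  # [False  True  True False False] -> drops the real first window
print(mask > 0)       # [ True  True  True False False] -> keeps exactly the real ones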
31 changes: 16 additions & 15 deletions tools/run_step_recog.py
@@ -5,8 +5,7 @@
 import time
 from torch.utils.data import DataLoader
 from step_recog.config import load_config
-from step_recog import train, evaluate, build_model
-from step_recog.datasets import Milly_multifeature_v4, collate_fn
+from step_recog import datasets, train, evaluate, build_model
 from sklearn.model_selection import KFold, train_test_split
 import pandas as pd, pdb, numpy as np
 import math
@@ -63,23 +62,24 @@ def main():
     if cfg.TRAIN.USE_CROSS_VALIDATION:
       train_kfold(cfg, args)
     else:
-      train_kfold_step(cfg)
+      train_hold_out(cfg)
   else:
+    DATASET_CLASS = getattr(datasets, cfg.DATASET.CLASS)
     model, _ = build_model(cfg)
     weights = torch.load(cfg.MODEL.OMNIGRU_CHECKPOINT_URL)
     model.load_state_dict(model.update_version(weights))

     data = pd.read_csv(cfg.DATASET.TS_ANNOTATIONS_FILE)
     videos = data.video_id.unique()
     _, video_test = my_train_test_split(cfg, videos)
-    ts_dataset = Milly_multifeature_v4(cfg, split='test', filter = video_test)
+    ts_dataset = DATASET_CLASS(cfg, split='test', filter = video_test)

     ts_data_loader = DataLoader(
       ts_dataset,
       shuffle=False,
       batch_size=cfg.TRAIN.BATCH_SIZE,
       num_workers=min(math.ceil(len(ts_dataset) / cfg.TRAIN.BATCH_SIZE), cfg.DATALOADER.NUM_WORKERS),
-      collate_fn=collate_fn,
+      collate_fn=datasets.collate_fn,
       drop_last=False,
       timeout=timeout)
@@ -119,28 +119,29 @@ def train_kfold(cfg, args, k = 10):
     video_train = videos[train_idx]
     video_val = videos[val_idx]

-    train_kfold_step(cfg, os.path.join(main_path, "fold_{:02d}".format(idx) ), video_train, video_val, video_test)
+    train_hold_out(cfg, os.path.join(main_path, "fold_{:02d}".format(idx) ), video_train, video_val, video_test)

-def train_kfold_step(cfg, main_path = None, video_train = None, video_val = None, video_test = None):
+def train_hold_out(cfg, main_path = None, video_train = None, video_val = None, video_test = None):
+  DATASET_CLASS = getattr(datasets, cfg.DATASET.CLASS)
   timeout = 0

-  tr_dataset = Milly_multifeature_v4(cfg, split='train', filter=video_train)
-  vl_dataset = Milly_multifeature_v4(cfg, split='validation', filter=video_val)
+  tr_dataset = DATASET_CLASS(cfg, split='train', filter=video_train)
+  vl_dataset = DATASET_CLASS(cfg, split='validation', filter=video_val)

   tr_data_loader = DataLoader(
     tr_dataset,
     shuffle=False,
     batch_size=cfg.TRAIN.BATCH_SIZE,
     num_workers=min(math.ceil(len(tr_dataset) / cfg.TRAIN.BATCH_SIZE), cfg.DATALOADER.NUM_WORKERS),
-    collate_fn=collate_fn,
+    collate_fn=datasets.collate_fn,
     drop_last=True,
     timeout=timeout)
   vl_data_loader = DataLoader(
     vl_dataset,
     shuffle=False,
     batch_size=cfg.TRAIN.BATCH_SIZE,
     num_workers=min(math.ceil(len(vl_dataset) / cfg.TRAIN.BATCH_SIZE), cfg.DATALOADER.NUM_WORKERS),
-    collate_fn=collate_fn,
+    collate_fn=datasets.collate_fn,
     drop_last=False,
     timeout=timeout)
@@ -168,19 +169,19 @@ def train_kfold_step(cfg, main_path = None, video_train = None, video_val = None, video_test = None):
   weights = torch.load(model_name)
   model.load_state_dict(model.update_version(weights))

-  cfg.OUTPUT.LOCATION = val_path
-  evaluate(model, vl_data_loader, cfg)
+  ## cfg.OUTPUT.LOCATION = val_path
+  ## evaluate(model, vl_data_loader, cfg)

   del vl_data_loader
   del vl_dataset

-  ts_dataset = Milly_multifeature_v4(cfg, split='test', filter = video_test)
+  ts_dataset = DATASET_CLASS(cfg, split='test', filter = video_test)
   ts_data_loader = DataLoader(
     ts_dataset,
     shuffle=False,
     batch_size=cfg.TRAIN.BATCH_SIZE,
     num_workers=min(math.ceil(len(ts_dataset) / cfg.TRAIN.BATCH_SIZE), cfg.DATALOADER.NUM_WORKERS),
-    collate_fn=collate_fn,
+    collate_fn=datasets.collate_fn,
     drop_last=False,
     timeout=timeout)

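
One detail worth noting in the loaders above: num_workers is capped by the number of batches, so small splits do not spawn idle workers. A quick check of that heuristic with made-up sizes:

import math

def pick_num_workers(dataset_len, batch_size, max_workers):
  # never ask for more workers than there are batches to hand out
  return min(math.ceil(dataset_len / batch_size), max_workers)

print(pick_num_workers(37, 8, 16))    # 5 batches   -> 5 workers
print(pick_num_workers(4096, 32, 8))  # 128 batches -> capped at 8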
