diff --git a/act_recog/config/defaults.py b/act_recog/config/defaults.py
index 01f17cb..3c6c506 100644
--- a/act_recog/config/defaults.py
+++ b/act_recog/config/defaults.py
@@ -2,6 +2,8 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

 """Configs."""
+import os
+import pathlib
 from fvcore.common.config import CfgNode

 _C = CfgNode()
@@ -46,7 +48,7 @@ def get_cfg():
     """
     Get a copy of the default config.
     """
-    return _C
+    return _C.clone()

 def load_config(args):
     """
@@ -58,10 +60,34 @@
     # Setup cfg.
     cfg = get_cfg()
     # Load config from cfg.
+    if isinstance(args, (str, pathlib.Path)):
+        args = args_hook(args)
     if args.cfg_file is not None:
-        cfg.merge_from_file(args.cfg_file)
+        cfg.merge_from_file(find_config_file(args.cfg_file))
     # Load config from command line, overwrite config from opts.
     if args.opts is not None:
         cfg.merge_from_list(args.opts)

     return cfg
+
+# get built-in configs from the step_recog/config directory
+CONFIG_DIR = pathlib.Path(__file__).parent.parent.parent / 'config'
+
+def find_config_file(cfg_file):
+    cfg_files = [
+        cfg_file,                         # you passed a valid config file path
+        CONFIG_DIR / cfg_file,            # a path relative to the config directory
+        CONFIG_DIR / f'{cfg_file}.yaml',  # the name without the extension
+        CONFIG_DIR / f'{cfg_file}.yml',
+    ]
+    for f in cfg_files:
+        if os.path.isfile(f):
+            return f
+    raise FileNotFoundError(cfg_file)
+
+
+def args_hook(cfg_file):
+    args = lambda: None
+    args.cfg_file = cfg_file
+    args.opts = None
+    return args
diff --git a/step_recog/config/defaults.py b/step_recog/config/defaults.py
index 0125e7b..01e7d29 100644
--- a/step_recog/config/defaults.py
+++ b/step_recog/config/defaults.py
@@ -2,8 +2,11 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

 """Configs."""
+import os
+import pathlib
 from fvcore.common.config import CfgNode

+
 _C = CfgNode()

 # -----------------------------------------------------------------------------
@@ -51,9 +54,12 @@
 _C.MODEL.YOLO_CHECKPOINT_URL = ''
 _C.MODEL.OMNIGRU_CHECKPOINT_URL = ''
+_C.MODEL.PRETRAINED_CHECKPOINT_URL = ''

 _C.MODEL.OMNIVORE_CONFIG = 'OMNIVORE'
 _C.MODEL.SLOWFAST_CONFIG = 'SLOWFAST'

+_C.MODEL.VARIANTS = CfgNode(new_allowed=True)
+
 # -----------------------------------------------------------------------------
 # Dataset options
 # -----------------------------------------------------------------------------
@@ -94,7 +100,7 @@ def get_cfg():
     """
     Get a copy of the default config.
     """
-    return _C
+    return _C.clone()

 def load_config(args):
     """
@@ -106,10 +112,37 @@
     # Setup cfg.
     cfg = get_cfg()
     # Load config from cfg.
+    if isinstance(args, (str, pathlib.Path)):
+        args = args_hook(args)
     if args.cfg_file is not None:
-        cfg.merge_from_file(args.cfg_file)
+        #cfg.merge_from_file(args.cfg_file)
+        cfg.merge_from_file(find_config_file(args.cfg_file))
     # Load config from command line, overwrite config from opts.
     if args.opts is not None:
         cfg.merge_from_list(args.opts)

-    return cfg
\ No newline at end of file
+    return cfg
+
+# get built-in configs from the step_recog/config directory
+CONFIG_DIR = pathlib.Path(__file__).parent.parent.parent / 'config'
+
+
+def find_config_file(cfg_file):
+    cfg_files = [
+        cfg_file,                         # you passed a valid config file path
+        CONFIG_DIR / cfg_file,            # a path relative to the config directory
+        CONFIG_DIR / f'{cfg_file}.yaml',  # the name without the extension
+        CONFIG_DIR / f'{cfg_file}.yml',
+    ]
+    for f in cfg_files:
+        if os.path.isfile(f):
+            return f
+    raise FileNotFoundError(cfg_file)
+
+
+def args_hook(cfg_file):
+    args = lambda: None
+    args.cfg_file = cfg_file
+    args.opts = None
+    return args
+
diff --git a/step_recog/full/model.py b/step_recog/full/model.py
index 8238062..1462826 100644
--- a/step_recog/full/model.py
+++ b/step_recog/full/model.py
@@ -1,5 +1,6 @@
 import numpy as np
 import torch
+import functools
 from torch import nn
 from collections import deque
 from ultralytics import YOLO
@@ -31,14 +32,22 @@
 def build_model(cfg_file, fps):
   return MODEL_CLASS(cfg_file, fps).to("cuda")

+
+@functools.lru_cache(1)
+def get_omnivore(cfg_fname):
+  omni_cfg = act_load_config(args_hook(cfg_fname))
+  omnivore = Omnivore(omni_cfg, resize = False)
+  return omnivore, omni_cfg
+
+
 class StepPredictor(nn.Module):
   """Step prediction model that takes in frames and outputs step probabilities.
   """
   def __init__(self, cfg_file, video_fps = 30):
     super().__init__()
-    # load config
     self._device = nn.Parameter(torch.empty(0))
-    self.cfg = load_config(args_hook(cfg_file))
+    # load config
+    self.cfg = load_config(args_hook(cfg_file)).clone() # clone prob not necessary but tinfoil

     # assign vocabulary
     self.STEPS = np.array([
@@ -74,34 +83,60 @@ def forward(self, image, queue_frame = True):
 class StepPredictor_GRU(StepPredictor):
   def __init__(self, cfg_file, video_fps = 30):
     super().__init__(cfg_file, video_fps)
-    self.omni_cfg = act_load_config(args_hook(self.cfg.MODEL.OMNIVORE_CONFIG))
+#    self.omni_cfg = act_load_config(args_hook(self.cfg.MODEL.OMNIVORE_CONFIG))
     self.MAX_OBJECTS = 25
-    self.transform = transforms.Compose([
-      transforms.Resize(self.omni_cfg.MODEL.IN_SIZE),
-      transforms.CenterCrop(self.omni_cfg.MODEL.IN_SIZE)
-    ])
+#    self.transform = transforms.Compose([
+#      transforms.Resize(self.omni_cfg.MODEL.IN_SIZE),
+#      transforms.CenterCrop(self.omni_cfg.MODEL.IN_SIZE)
+#    ])

     # build model
     self.head = OmniGRU(self.cfg, load=True)
     self.head.eval()
+    frame_queue_len = 1
     if self.cfg.MODEL.USE_ACTION:
-      self.omnivore = Omnivore(self.omni_cfg, resize = False)
+      omnivore, omni_cfg = get_omnivore(self.cfg.MODEL.OMNIVORE_CONFIG)
+      self.omnivore = omnivore
+      self.omni_cfg = omni_cfg
+      frame_queue_len = self.omni_cfg.DATASET.FPS * self.omni_cfg.MODEL.WIN_LENGTH
+      frame_queue_len = video_fps * self.omni_cfg.MODEL.WIN_LENGTH #default: 2seconds
+      self.transform = transforms.Compose([
+        transforms.Resize(self.omni_cfg.MODEL.IN_SIZE),
+        transforms.CenterCrop(self.omni_cfg.MODEL.IN_SIZE)
+      ])
+      #self.omnivore = Omnivore(self.omni_cfg, resize = False)
     if self.cfg.MODEL.USE_OBJECTS:
       yolo_checkpoint = cached_download_file(self.cfg.MODEL.YOLO_CHECKPOINT_URL)
       self.yolo = YOLO(yolo_checkpoint)
       self.yolo.eval = lambda *a: None
       self.clip_patches = ClipPatches(utils.clip_download_root)
       self.clip_patches.eval()
+      names = self.yolo.names
+      self.OBJECT_LABELS = np.array([str(names.get(i, i)) for i in range(len(names))])
+    else:
+      self.OBJECT_LABELS = np.array([], dtype=str)

     if self.cfg.MODEL.USE_AUDIO:
       raise NotImplementedError()

     # frame buffers and model state
-    self.create_queue(video_fps * self.omni_cfg.MODEL.WIN_LENGTH) #default: 2seconds
+    self.frame_queue_len = frame_queue_len
+    self.create_queue(frame_queue_len) #default: 2seconds
     self.h = None
+
+  def eval(self):
+    y=self.yolo
+    self.yolo = None
+    super().eval()
+    self.head.eval()
+    self.omnivore.eval()
+    self.yolo=y
+    return self
+

   def reset(self):
-    super().__init__()
+    #super().__init__()
+    super().reset()
     self.h = None

   def queue_frame(self, image):
@@ -115,7 +150,7 @@ def queue_frame(self, image):
   def prepare(self, im):
     return self.transform(Image.fromarray(im))

-  def forward(self, image, queue_frame = True):
+  def forward(self, image, queue_frame = True, return_objects=False):
     # compute yolo
     Z_objects, Z_frame = torch.zeros((1, 1, 25, 0)).float(), torch.zeros((1, 1, 1, 0)).float()
     if self.cfg.MODEL.USE_OBJECTS:
@@ -145,6 +180,7 @@
       self.queue_frame(image)

     # compute omnivore embeddings
+    # [1, 32, 3, H, W]
     X_omnivore = torch.stack(list(self.input_queue), dim=1)[None]
     frame_idx = np.linspace(0, self.input_queue.maxlen - 1, self.omni_cfg.MODEL.NFRAMES).astype('long') #same as act_recog.dataset.milly.py:pack_frames_to_video_clip
     X_omnivore = X_omnivore[:, :, frame_idx, :, :]
@@ -154,9 +190,19 @@
     # mix it all together
     if self.h is None:
       self.h = self.head.init_hidden(Z_action.shape[0])
-
-    prob_step, self.h = self.head(Z_action.to(self._device.device), self.h.float(), Z_audio.to(self._device.device), Z_objects.to(self._device.device), Z_frame.to(self._device.device))
+
+    device = self._device.device
+    prob_step, self.h = self.head(
+        Z_action.to(device),
+        self.h.float(),
+        Z_audio.to(device),
+        Z_objects.to(device),
+        Z_frame.to(device))
+
     prob_step = torch.softmax(prob_step[..., :-2].detach(), dim=-1) #prob_step has <1 no step position> <2 begin-end frame identifiers>
+
+    if return_objects:
+      return prob_step, results
     return prob_step

 class StepPredictor_Transformer(StepPredictor):
@@ -179,4 +225,4 @@ def forward(self, image, queue_frame = True):
     prob_step = self.head(image.to(self._device.device), self.steps_feat.to(self._device.device))
     prob_step = torch.softmax(prob_step.detach(), dim = -1)

-    return prob_step
\ No newline at end of file
+    return prob_step
diff --git a/step_recog/models.py b/step_recog/models.py
index a2f30f6..b2e2b6a 100644
--- a/step_recog/models.py
+++ b/step_recog/models.py
@@ -4,6 +4,7 @@
 import torch
 from collections import OrderedDict
+from step_recog.full.download import cached_download_file

 device = 'cuda' if torch.cuda.is_available() else 'cpu'


@@ -59,7 +60,8 @@ def __init__(self, cfg, load = False):
     self.relu = torch.nn.ReLU()

     if load:
-      self.load_state_dict( self.update_version(torch.load( cfg.MODEL.OMNIGRU_CHECKPOINT_URL )))
+      f = cfg.MODEL.OMNIGRU_CHECKPOINT_URL or cached_download_file(cfg.MODEL.PRETRAINED_CHECKPOINT_URL)
+      self.load_state_dict(self.update_version(torch.load(f)))
     else:
       self.apply(custom_weights)
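
Usage sketch of the string/Path support this patch adds to load_config (illustrative only: the config name "STEPGRU" and the explicit path below are assumptions, not files confirmed by this diff):

    # A plain string or pathlib.Path is now wrapped by args_hook() and resolved
    # against the repo-level config/ directory by find_config_file().
    from step_recog.config.defaults import load_config

    cfg = load_config("STEPGRU")              # assumed name; tries config/STEPGRU.yaml, then .yml
    cfg = load_config("config/STEPGRU.yaml")  # an existing file path is used as-is
    print(cfg.MODEL.OMNIGRU_CHECKPOINT_URL)   # defaults merge exactly as before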