diff --git a/README.md b/README.md index c792138..16f99d2 100644 --- a/README.md +++ b/README.md @@ -131,8 +131,8 @@ python tools/test.py --cfg config/M3.yaml 2. Training/evaluation routines: `step_recog/iterators.py` (functions *train*, *evaluate*) 3. Model classes: `step_recog/models.py` 4. Dataloader: `step_recog/datasets/milly.py` (methods *_construct_loader* and *__getitem\__*) -- class *Milly_multifeature_v4* loads video frames and returns formated features -- class *Milly_multifeature_v5* loads (preprocessed) features and returns formated features +- class *Milly_multifeature_v4* loads video frames and returns features +- class *Milly_multifeature_v5* loads and returns (preprocessed) features 5. Image augmentation: `tools/augmentation.py` (function *get_augmentation*) 6. Basic configuration: `step_recog/config/defaults.py` (more important), `act_recog/config/defaults.py`, `auditory_slowfast/config/defaults.py` 6. Visualizer: `step_recog/full/visualize.py` implements a specific code that combines dataloading, model prediction, and a state machine. It uses the user interface with the trained models. 
diff --git a/act_recog/models/video_model_builder.py b/act_recog/models/video_model_builder.py index 3c07eb6..979236c 100644 --- a/act_recog/models/video_model_builder.py +++ b/act_recog/models/video_model_builder.py @@ -14,14 +14,19 @@ from torch.nn.init import normal_ from torch.utils import model_zoo from copy import deepcopy +from PIL import Image import pdb from .build import MODEL_REGISTRY from act_recog.datasets.transform import uniform_crop +from torchvision import transforms + +def max_norm(x): + return x / x.max() @MODEL_REGISTRY.register() class Omnivore(nn.Module): - def __init__(self, cfg): + def __init__(self, cfg, resize = True): super().__init__() # model @@ -32,6 +37,21 @@ def __init__(self, cfg): self.heads = self.model.heads self.model.heads = nn.Identity() + self.transform = transforms.Compose([ + transforms.ToPILImage(), + transforms.Resize(self.cfg.MODEL.IN_SIZE), + transforms.CenterCrop(self.cfg.MODEL.IN_SIZE), + transforms.ToTensor(), + transforms.Lambda(max_norm), + transforms.Normalize(mean=self.cfg.MODEL.MEAN, std=self.cfg.MODEL.STD), + ]) + + if not resize: + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Lambda(max_norm), + transforms.Normalize(mean=self.cfg.MODEL.MEAN, std=self.cfg.MODEL.STD), + ]) def forward(self, x, return_embedding=False): # C T H W shoulder = self.model(x, input_type="video") @@ -39,8 +59,16 @@ def forward(self, x, return_embedding=False): # C T H W if return_embedding: return y, shoulder return y - + def prepare_image(self, im, bgr2rgb = True): + # 1,C,H,W + if isinstance(im, Image.Image): + im = np.array(im) + + im = self.transform(im).float() + return im + + def prepare_image_v2(self, im, bgr2rgb = True): # 1,C,H,W im = prepare_image(im, self.cfg.MODEL.MEAN, self.cfg.MODEL.STD, self.cfg.MODEL.IN_SIZE, bgr2rgb) return im diff --git a/example/Perception_examples.ipynb b/example/Perception_examples.ipynb new file mode 100644 index 0000000..13a97b8 --- /dev/null +++ 
b/example/Perception_examples.ipynb @@ -0,0 +1,365 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "d4401ccb-8af6-4a95-bc60-4d5c5a2e603c", + "metadata": {}, + "outputs": [], + "source": [ + "## Follow the instructions to install the code https://github.com/VIDA-NYU/Perception-training\n", + "## Also install the jupyter https://pypi.org/project/jupyter/\n", + "## You'll not be able to run with a GPU with less than 12GB of memory\n", + "## It's going to take too long to run with CPU " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71456c91-8731-45a7-92c8-943eadc12b0a", + "metadata": {}, + "outputs": [], + "source": [ + "1. Evaluating with dataloader and model (interface used for training evaluation)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "6e6b51bd-5dbf-4f00-9152-ffb796c5d22b", + "metadata": {}, + "outputs": [], + "source": [ + "import math, torch, cv2, matplotlib.pyplot as plt\n", + "from torch.utils.data import DataLoader\n", + "\n", + "from step_recog import evaluate, build_model\n", + "from step_recog.config import load_config\n", + "from step_recog.datasets import Milly_multifeature_v4, collate_fn\n", + "\n", + "def args_hook(cfg_file):\n", + " args = lambda: None\n", + " args.cfg_file = cfg_file\n", + " args.opts = None \n", + " return args" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "250c1ef8-8557-4349-8048-1de0c342722f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Video: 100%|██████████| 1/1 [00:00<00:00, 29.28it/s, window total=83, padded videos=0]\n", + "Using cache found in /home/../.cache/torch/hub/facebookresearch_omnivore_main\n", + "/ext3/miniconda3/lib/python3.12/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "/ext3/miniconda3/lib/python3.12/site-packages/torch/functional.py:512: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3587.)\n", + " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cfg = args_hook(\"config/M2.yaml\")\n", + "cfg = load_config(cfg)\n", + "\n", + "#reads frames and returns features\n", + "dataset = Milly_multifeature_v4(cfg, split='test')\n", + "data_loader = DataLoader(\n", + " dataset, \n", + " shuffle=False, \n", + " batch_size=cfg.TRAIN.BATCH_SIZE,\n", + " collate_fn=collate_fn,\n", + " drop_last=False)\n", + "\n", + "model, _ = build_model(cfg)\n", + "weights = torch.load(\"models/M2.pt\")\n", + "model.load_state_dict(weights)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "7e41e03b-a2ed-4dd4-b329-cee073dee8fd", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "WARNING ⚠️ NMS time limit 0.550s exceeded\n" + ] + } + ], + "source": [ + "#saves classification report, confusion matrix, and video evaluation (image with ground-truth, predicted and confidence)\n", + "evaluate(model, data_loader, cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "04e691d0-9a07-4863-b98e-e3171b2be462", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cm = cv2.imread(\"output/confusion_matrix.png\")\n", + "cm = cv2.cvtColor(cm, cv2.COLOR_BGR2RGB)\n", + "\n", + "video_eval = cv2.imread(\"output/video_evaluation/M2-19-step_variation.png\")\n", + "video_eval = cv2.cvtColor(video_eval, cv2.COLOR_BGR2RGB)\n", + "\n", + "figure = plt.figure(figsize = (1366 / 100, 768 / 100), dpi = 100) \n", + "plt.subplot(1, 2, 1)\n", + "plt.imshow(cm)\n", + "plt.axis('off')\n", + "\n", + "plt.subplot(1, 2, 2)\n", + "plt.imshow(video_eval)\n", + "plt.axis('off')\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1d0f1446-284a-4037-891f-817a96f4bbff", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " Step 1 1.00 1.00 1.00 10\n", + " Step 2 0.38 1.00 0.56 5\n", + " Step 3 0.50 0.14 0.22 7\n", + " Step 4 0.75 0.86 0.80 7\n", + " Step 5 0.33 1.00 0.50 5\n", + " Step 6 0.00 0.00 0.00 0\n", + " Step 7 0.00 0.00 0.00 0\n", + " Step 8 0.60 1.00 0.75 3\n", + " No step 0.97 0.63 0.76 46\n", + "\n", + " accuracy 0.71 83\n", + " macro avg 0.50 0.63 0.51 83\n", + "weighted avg 0.83 0.71 0.72 83\n", + "\n", + "\n", + "Categorical accuracy: 0.71\n", + "Weighted accuracy: 0.80\n", + "Balanced accuracy: 0.80\n", + "\n" + ] + } + ], + "source": [ + "metrics = open(\"output/metrics.txt\")\n", + "metrics = metrics.read()\n", + "print(metrics)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7522c4c4-0745-4ed4-8b7f-5809012c60b7", + "metadata": {}, + "outputs": [], + "source": [ + "2. 
Evaluating with StepPredictor and ProcedureStateMachine (interface used by BBN)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "258cc394-f256-41a9-9ff8-dcc151fb22f3", + "metadata": {}, + "outputs": [], + "source": [ + "import supervision as sv, tqdm, torch, numpy as np, ipdb, pandas as pd\n", + "\n", + "from step_recog.full.model import StepPredictor\n", + "from step_recog.full.statemachine import ProcedureStateMachine" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "4a8c9108-35de-4dc7-bd24-781880316ce4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using cache found in /home/../.cache/torch/hub/facebookresearch_omnivore_main\n" + ] + } + ], + "source": [ + "video_path = \"videos/M2-19.mp4\"\n", + "video_info = sv.VideoInfo.from_video_path(video_path)\n", + "step_process = video_info.fps #1 second by default\n", + "\n", + "model = StepPredictor(\"config/M2.yaml\", video_info.fps).to(\"cuda\")\n", + "psm = ProcedureStateMachine(model.cfg.MODEL.OUTPUT_DIM)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1fa51108-c221-4b15-9967-99a215432009", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "1870it [00:24, 74.98it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " window_end_frame window_end_sec step_idx step_desc \\\n", + "0 0 0.000000 8 No step \n", + "1 1 0.033333 8 No step \n", + "2 2 0.066667 8 No step \n", + "3 3 0.100000 8 No step \n", + "4 4 0.133333 8 No step \n", + "... ... ... ... ... 
\n", + "1865 1865 62.166667 8 No step \n", + "1866 1866 62.200000 8 No step \n", + "1867 1867 62.233333 8 No step \n", + "1868 1868 62.266667 8 No step \n", + "1869 1869 62.300000 8 No step \n", + "\n", + " step_state \n", + "0 [0, 0, 0, 0, 0, 0, 0, 0] \n", + "1 [0, 0, 0, 0, 0, 0, 0, 0] \n", + "2 [0, 0, 0, 0, 0, 0, 0, 0] \n", + "3 [0, 0, 0, 0, 0, 0, 0, 0] \n", + "4 [0, 0, 0, 0, 0, 0, 0, 0] \n", + "... ... \n", + "1865 [2, 2, 2, 2, 2, 0, 0, 1] \n", + "1866 [2, 2, 2, 2, 2, 0, 0, 1] \n", + "1867 [2, 2, 2, 2, 2, 0, 0, 1] \n", + "1868 [2, 2, 2, 2, 2, 0, 0, 1] \n", + "1869 [2, 2, 2, 2, 2, 0, 0, 1] \n", + "\n", + "[1870 rows x 5 columns]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "prob_step = np.zeros(model.cfg.MODEL.OUTPUT_DIM + 1)\n", + "prob_step[-1] = 1.0\n", + "step_desc = \"No step\"\n", + "\n", + "window_prediction = {\"window_end_frame\": [], \"window_end_sec\": [], \"step_idx\": [], \"step_desc\": [], \"step_state\": []}\n", + "model.reset()\n", + "psm.reset()\n", + "\n", + "#Iterates over one video\n", + "for idx, frame in tqdm.tqdm(enumerate(sv.get_video_frames_generator(video_path))): \n", + " frame_aux = model.prepare(frame)\n", + " model.queue_frame(frame_aux)\n", + "\n", + " if idx % step_process == 0:\n", + " prob_step = model(frame_aux, queue_omni_frame = False).cpu().squeeze().numpy() \n", + " step_idx = np.argmax(prob_step)\n", + " step_desc = \"No step\" if step_idx >= len(model.STEPS) else model.STEPS[step_idx] \n", + "\n", + " #Most important output for BBN personnel\n", + " psm.process_timestep(prob_step)\n", + "\n", + " window_prediction[\"window_end_frame\"].append(idx)\n", + " window_prediction[\"window_end_sec\"].append(idx / video_info.fps)\n", + " window_prediction[\"step_idx\"].append(step_idx)\n", + " window_prediction[\"step_desc\"].append(step_desc)\n", + " current_state = psm.current_state.copy()\n", + " window_prediction[\"step_state\"].append(current_state)\n", + "\n", 
+ "window_prediction = pd.DataFrame(window_prediction)\n", + "print(window_prediction)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5d349617-bae8-4714-bdfa-885a77922c26", + "metadata": {}, + "outputs": [], + "source": [ + "3. Evaluating with visualize.py (it uses StepPredictor and ProcedureStateMachine)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7f711e4-5ee6-48c0-a7d4-52eb91f52777", + "metadata": {}, + "outputs": [], + "source": [ + "#Plots a histogram on the top-left of the video with: step probabilities and states\n", + "!python ../step_recog/full/visualize.py videos/M2-19.mp4 output/out.mp4 ../config/M2.yaml" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "sing", + "language": "python", + "name": "sing" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/example/config/M2.yaml b/example/config/M2.yaml new file mode 100644 index 0000000..53558ae --- /dev/null +++ b/example/config/M2.yaml @@ -0,0 +1,28 @@ +_BASE_: STEPGRU_BASE.yaml +MODEL: + OMNIGRU_CHECKPOINT_URL: 'models/M2.pt' + OUTPUT_DIM: 8 + YOLO_CHECKPOINT_URL: 'models/bbn_yolo_M2.pt' + +DATASET: + TR_ANNOTATIONS_FILE: "labels/M2_Alabama+BBN_videos_M2-19.csv" + VL_ANNOTATIONS_FILE: "labels/M2_Alabama+BBN_videos_M2-19.csv" + TS_ANNOTATIONS_FILE: "labels/M2_Alabama+BBN_videos_M2-19.csv" + +OUTPUT: + LOCATION: "output" + +TRAIN: + ENABLE: False + +SKILLS: + - NAME: M2 - Apply Tourniquet + STEPS: + - Place tourniquet over affected extremity 2-3 inches above wound site. + - Pull tourniquet tight. + - Apply strap to strap body. + - Turn windless clock wise or counter clockwise until hemorrhage is controlled. + - Lock windless into the windless keeper. 
+ - Pull remaining strap over the windless keeper. + - Secure strap and windless keeper with keeper securing device. + - Mark time on securing device strap with permanent marker. diff --git a/example/config/OMNIVORE.yaml b/example/config/OMNIVORE.yaml new file mode 100644 index 0000000..c9945bc --- /dev/null +++ b/example/config/OMNIVORE.yaml @@ -0,0 +1,6 @@ +MODEL: + ARCH: omnivore_swinB_epic + MODEL_NAME: Omnivore + NFRAMES: 32 + MEAN: [0.485, 0.456, 0.406] + STD: [0.229, 0.224, 0.225] \ No newline at end of file diff --git a/example/config/SLOWFAST_R50.yaml b/example/config/SLOWFAST_R50.yaml new file mode 100644 index 0000000..7d3cce9 --- /dev/null +++ b/example/config/SLOWFAST_R50.yaml @@ -0,0 +1,68 @@ +TRAIN: + ENABLE: False + DATASET: epickitchens + BATCH_SIZE: 64 + EVAL_PERIOD: 2 + CHECKPOINT_PERIOD: 1 + CHECKPOINT_EPOCH_RESET: True + AUTO_RESUME: True + CHECKPOINT_FILE_PATH: "/home/user/data/SLOWFAST-AUDITORY/SLOWFAST_EPIC.pyth" +DATA: + INPUT_CHANNEL_NUM: [1, 1] +AUDIO_DATA: + CLIP_SECS: 1.999 + NUM_FRAMES: 400 +SLOWFAST: + ALPHA: 4 + BETA_INV: 8 + FUSION_CONV_CHANNEL_RATIO: 2 + FUSION_KERNEL_SZ: 7 +RESNET: + ZERO_INIT_FINAL_BN: True + WIDTH_PER_GROUP: 64 + NUM_GROUPS: 1 + DEPTH: 50 + TRANS_FUNC: bottleneck_transform + STRIDE_1X1: False + NUM_BLOCK_TEMP_KERNEL: [[3, 3], [4, 4], [6, 6], [3, 3]] + FREQUENCY_STRIDES: [[1, 1], [2, 2], [2, 2], [2, 2]] + FREQUENCY_DILATIONS: [[1, 1], [1, 1], [1, 1], [1, 1]] +BN: + USE_PRECISE_STATS: True + FREEZE: True + NUM_BATCHES_PRECISE: 200 +SOLVER: + BASE_LR: 0.001 + LR_POLICY: steps_with_relative_lrs + STEPS: [0, 20, 25] + LRS: [1, 0.1, 0.01] + MAX_EPOCH: 30 + MOMENTUM: 0.9 + WEIGHT_DECAY: 1e-4 + WARMUP_EPOCHS: -1.0 + WARMUP_START_LR: 0.01 + OPTIMIZING_METHOD: sgd +MODEL: + NUM_CLASSES: [34, 34] + ARCH: slowfast + MODEL_NAME: SlowFast + LOSS_FUNC: cross_entropy + DROPOUT_RATE: 0.5 +TEST: + ENABLE: False + DATASET: epickitchens + BATCH_SIZE: 32 + NUM_ENSEMBLE_VIEWS: 1 +DATA_LOADER: + NUM_WORKERS: 8 + PIN_MEMORY: True 
+EPICKITCHENS: + TRAIN_PLUS_VAL: False + AUDIO_DATA_FILE: "/home/user/data/BBN/new/M1/sound/files/BBN-M1-audio-windows_with_epic-structure.hdf5" + ANNOTATIONS_DIR: "/home/user/data/BBN/new/M1/sound/files" +NUM_GPUS: 1 +NUM_SHARDS: 1 +RNG_SEED: 0 +OUTPUT_DIR: "/home/user/data/BBN/new/M1/sound" +EXTRACT: + ENABLE: True diff --git a/example/config/STEPGRU_BASE.yaml b/example/config/STEPGRU_BASE.yaml new file mode 100644 index 0000000..188128e --- /dev/null +++ b/example/config/STEPGRU_BASE.yaml @@ -0,0 +1,31 @@ +MODEL: + HIDDEN_SIZE: 1024 + CONTEXT_LENGTH: 'full' + USE_ACTION: True ##default true + USE_OBJECTS: True ##default true + USE_AUDIO: False ##default false + USE_BN: False ##default false + DROP_OUT: 0.5 + + OMNIVORE_CONFIG: 'config/OMNIVORE.yaml' + SLOWFAST_CONFIG: 'config/SLOWFAST_R50.yaml' +DATASET: + NAME: 'Milly' + LOCATION: 'videos/frames' + AUDIO_LOCATION: '/sound' + INCLUDE_IMAGE_AUGMENTATIONS: True + INCLUDE_TIME_AUGMENTATIONS: False + IMAGE_AUGMENTATION_PERCENTAGE: 0.8 +DATALOADER: + NUM_WORKERS: 12 + PIN_MEMORY: True +TRAIN: + ENABLE: True + USE_CROSS_VALIDATION: True ##default true + USE_CLASS_WEIGHT: True ##default true + NUM_GPUS: 1 + BATCH_SIZE: 8 #32 + OPT: "adam" #adam sgd rmsprop + LR: 0.001 + EPOCHS: 25 + CV_TEST_TYPE: None # 10p bbn None diff --git a/example/labels/M2_Alabama+BBN_videos_M2-19.csv b/example/labels/M2_Alabama+BBN_videos_M2-19.csv new file mode 100644 index 0000000..bb455b4 --- /dev/null +++ b/example/labels/M2_Alabama+BBN_videos_M2-19.csv @@ -0,0 +1,8 @@ +narration_id,participant_id,video_id,narration_timestamp,start_timestamp,stop_timestamp,start_frame,stop_frame,narration,verb,verb_class,noun,noun_class,all_nouns,all_noun_classes,video_fps +302,M2,M2-19,00:00:00.000,00:00:00.000,00:00:00.000,374,638,Place tourniquet with over effected extremity 2-3 inches above wound site.,Place tourniquet with over effected extremity 2-3 inches above wound site.,0,Place tourniquet with over effected extremity 2-3 inches above wound 
site.,0,['Place tourniquet with over effected extremity 2-3 inches above wound site.'],[0],30 +303,M2,M2-19,00:00:00.000,00:00:00.000,00:00:00.000,677,785,Pull tourniquet tight.,Pull tourniquet tight.,1,Pull tourniquet tight.,1,['Pull tourniquet tight.'],[1],30 +304,M2,M2-19,00:00:00.000,00:00:00.000,00:00:00.000,806,884,Cinch tourniquet strap.,Cinch tourniquet strap.,2,Cinch tourniquet strap.,2,['Cinch tourniquet strap.'],[2],30 +305,M2,M2-19,00:00:00.000,00:00:00.000,00:00:00.000,896,1072,Turn windless clock wise or counter clockwise until hemorrhage is controlled .,Turn windless clock wise or counter clockwise until hemorrhage is controlled .,3,Turn windless clock wise or counter clockwise until hemorrhage is controlled .,3,['Turn windless clock wise or counter clockwise until hemorrhage is controlled .'],[3],30 +306,M2,M2-19,00:00:00.000,00:00:00.000,00:00:00.000,1187,1247,Cinch tourniquet strap.,Cinch tourniquet strap.,2,Cinch tourniquet strap.,2,['Cinch tourniquet strap.'],[2],30 +307,M2,M2-19,00:00:00.000,00:00:00.000,00:00:00.000,1254,1371,Lock windless into the windless keeper.,Lock windless into the windless keeper.,4,Lock windless into the windless keeper.,4,['Lock windless into the windless keeper.'],[4],30 +308,M2,M2-19,00:00:00.000,00:00:00.000,00:00:00.000,1454,1503,Mark time on securing device strap with permanent marker.,Mark time on securing device strap with permanent marker.,7,Mark time on securing device strap with permanent marker.,7,['Mark time on securing device strap with permanent marker.'],[7],30 diff --git a/example/models/.gitignore b/example/models/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/example/output/confusion_matrix.png b/example/output/confusion_matrix.png new file mode 100644 index 0000000..1acb8dc Binary files /dev/null and b/example/output/confusion_matrix.png differ diff --git a/example/output/metrics.txt b/example/output/metrics.txt new file mode 100644 index 0000000..06e063f --- /dev/null +++ 
b/example/output/metrics.txt @@ -0,0 +1,20 @@ + precision recall f1-score support + + Step 1 1.00 1.00 1.00 10 + Step 2 0.38 1.00 0.56 5 + Step 3 0.50 0.14 0.22 7 + Step 4 0.75 0.86 0.80 7 + Step 5 0.33 1.00 0.50 5 + Step 6 0.00 0.00 0.00 0 + Step 7 0.00 0.00 0.00 0 + Step 8 0.60 1.00 0.75 3 + No step 0.97 0.63 0.76 46 + + accuracy 0.71 83 + macro avg 0.50 0.63 0.51 83 +weighted avg 0.83 0.71 0.72 83 + + +Categorical accuracy: 0.71 +Weighted accuracy: 0.80 +Balanced accuracy: 0.80 diff --git a/example/output/video_evaluation/M2-19-step_variation.png b/example/output/video_evaluation/M2-19-step_variation.png new file mode 100644 index 0000000..06b06dc Binary files /dev/null and b/example/output/video_evaluation/M2-19-step_variation.png differ diff --git a/example/videos/.gitignore b/example/videos/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/example/videos/frames/.gitignore b/example/videos/frames/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/example/videos/frames/M2-19/.gitignore b/example/videos/frames/M2-19/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/step_recog/datasets/milly.py b/step_recog/datasets/milly.py index f1a5dbe..36b662b 100644 --- a/step_recog/datasets/milly.py +++ b/step_recog/datasets/milly.py @@ -140,6 +140,8 @@ def __len__(self): from ultralytics import YOLO #from torch.quantization import quantize_dynamic +from torchvision import transforms + from step_recog.full.download import cached_download_file from step_recog.full.clip_patches import ClipPatches @@ -180,6 +182,10 @@ def __init__(self, cfg, split='train', filter=None): self.augment_configs = {} self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + self.transform = transforms.Compose([ + transforms.Resize(self.omni_cfg.MODEL.IN_SIZE), + transforms.CenterCrop(self.omni_cfg.MODEL.IN_SIZE) + ]) if self.cfg.MODEL.USE_OBJECTS: yolo_checkpoint = 
cached_download_file(cfg.MODEL.YOLO_CHECKPOINT_URL) @@ -192,7 +198,7 @@ def __init__(self, cfg, split='train', filter=None): self.clip_patches.eval() if self.cfg.MODEL.USE_ACTION: - self.omnivore = Omnivore(self.omni_cfg) + self.omnivore = Omnivore(self.omni_cfg, resize = False) self.omnivore.eval() self.sound_cache = deque(maxlen=5) @@ -431,7 +437,7 @@ def _construct_loader(self, split): video_windows = [] previous_stop_frame = 1 - for _, step_ann in vid_ann.iterrows(): + for idx, step_ann in vid_ann.iterrows(): win_size = self.rng.integers(len(win_size_sec)) hop_size = self.rng.integers(len(hop_size_perc)) @@ -537,15 +543,6 @@ def augment_frames(self, frames, frame_ids, video_id): return frames - #Both CLIP and Omnivore resize to 224, 224 - #With this code, Yolo is using the same size - def _resize_img(self, im, expected_size=224): - scale = max(expected_size/im.shape[0], expected_size/im.shape[1]) - im = cv2.resize(im, (0,0), fx=scale, fy=scale) - im, _ = uniform_crop(im, expected_size, 1) - - return im - def _get_sound_cache(self, video, path): sound = None @@ -587,9 +584,10 @@ def _load_frames(self, window): ## frame = cv2.imread(frame_path) ## frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frame = Image.open(frame_path) + frame = self.transform(frame) frame = np.array(frame) - frame = self._resize_img(frame) self.frame_cache[frame_id] = {"frame": frame, "new": True} + window_frames.append(frame) window_frame_ids.append(frame_id) @@ -628,7 +626,7 @@ def _extract_img_features(self, window_frames): def _extract_act_features(self, window_frames): frame_idx = np.linspace(0, len(window_frames) - 1, self.omni_cfg.MODEL.NFRAMES).astype('long') - X_omnivore = [ self.omnivore.prepare_image(frame, bgr2rgb = False) for frame in window_frames ] + X_omnivore = [ self.omnivore.prepare_image(frame) for frame in window_frames ] X_omnivore = torch.stack(list(X_omnivore), dim=1)[None] X_omnivore = X_omnivore[:, :, frame_idx, :, :] _, Z_action = 
self.omnivore(X_omnivore.to(self.device), return_embedding=True) diff --git a/step_recog/full/clip_patches.py b/step_recog/full/clip_patches.py index 813ed10..36d4795 100644 --- a/step_recog/full/clip_patches.py +++ b/step_recog/full/clip_patches.py @@ -22,12 +22,17 @@ def stack_patches(self, patches): for x in patches ]) - def forward(self, image, xywh=None, patch_shape=None, include_frame=False): + def forward(self, image, xywh=None, patch_shape=None, include_frame=False): + if isinstance(image, Image.Image): + image = np.array(image) + patches = [] if xywh is None else extract_patches(image, xywh, patch_shape) + if include_frame: patches.insert(0, image) if not patches: return torch.zeros((0, 512), device=self._device.device) + X = self.stack_patches(patches) Z = self.model.encode_image(X) return Z diff --git a/step_recog/full/model.py b/step_recog/full/model.py index 83a3cc6..88bb6fa 100644 --- a/step_recog/full/model.py +++ b/step_recog/full/model.py @@ -5,6 +5,8 @@ from ultralytics import YOLO import ipdb import cv2 +from torchvision import transforms +from PIL import Image from act_recog.models import Omnivore from act_recog.config import load_config as act_load_config @@ -45,12 +47,16 @@ def __init__(self, cfg_file, video_fps = 30): for step in skill['STEPS'] ]) self.MAX_OBJECTS = 25 + self.transform = transforms.Compose([ + transforms.Resize(self.omni_cfg.MODEL.IN_SIZE), + transforms.CenterCrop(self.omni_cfg.MODEL.IN_SIZE) + ]) # build model self.head = OmniGRU(self.cfg, load=True) self.head.eval() if self.cfg.MODEL.USE_ACTION: - self.omnivore = Omnivore(self.omni_cfg) + self.omnivore = Omnivore(self.omni_cfg, resize = False) if self.cfg.MODEL.USE_OBJECTS: yolo_checkpoint = cached_download_file(self.cfg.MODEL.YOLO_CHECKPOINT_URL) self.yolo = YOLO(yolo_checkpoint) @@ -80,13 +86,7 @@ def queue_frame(self, image): self.omnivore_input_queue.append(X_omnivore) def prepare(self, im): - expected_size=224 - im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB) - scale = 
max(expected_size/im.shape[0], expected_size/im.shape[1]) - im = cv2.resize(im, (0,0), fx=scale, fy=scale) - im, _ = uniform_crop(im, expected_size, 1) - - return im + return self.transform(Image.fromarray(im)) def forward(self, image, queue_omni_frame = True): # compute yolo diff --git a/tools/run_step_recog.py b/tools/run_step_recog.py index 27de2e5..0fe4c49 100644 --- a/tools/run_step_recog.py +++ b/tools/run_step_recog.py @@ -102,8 +102,10 @@ def my_train_test_split(cfg, videos): videos, video_test = train_test_split(videos, test_size=0.10, random_state=2359) #M5 1030: only with BBN 041624.zip elif "R18" in cfg.SKILLS[0]["NAME"]: videos, video_test = train_test_split(videos, test_size=0.10, random_state=2343) #R18 1740: only with BBN seal_videos.zip - else: #A8, M4, R16, R19 - videos, video_test = train_test_split(videos, test_size=0.10, random_state=1030) + elif "A8" in cfg.SKILLS[0]["NAME"]: + videos, video_test = train_test_split(videos, test_size=0.10, random_state=2329) #A8: + else: #M4, R16, R19 + videos, video_test = train_test_split(videos, test_size=0.10, random_state=1030) return videos, video_test diff --git a/tools/test.py b/tools/test.py deleted file mode 100644 index f727b43..0000000 --- a/tools/test.py +++ /dev/null @@ -1,45 +0,0 @@ -import torch -import pandas as pd -import os -from torch.utils.data import DataLoader - -from step_recog.config import load_config -from step_recog import datasets, build_model, extract_features -from run_step_recog import parse_args, my_train_test_split - -def main(): - """ - Main function to spawn the process. 
- """ - args = parse_args() - cfg = load_config(args) - - if cfg.DATALOADER.NUM_WORKERS > 0: - torch.multiprocessing.set_start_method('spawn') - - # build the dataset - timeout = 0 - data = pd.read_csv(cfg.DATASET.TS_ANNOTATIONS_FILE) - _, video_test = my_train_test_split(cfg, data.video_id.unique()) - DATASET_CLASS = getattr(datasets, cfg.DATASET.CLASS) - - ts_dataset = DATASET_CLASS(cfg, split='test', filter=video_test) - ts_data_loader = DataLoader( - ts_dataset, - shuffle=False, - batch_size=cfg.TRAIN.BATCH_SIZE, - num_workers=cfg.DATALOADER.NUM_WORKERS, - collate_fn=datasets.collate_fn, - drop_last=False, - timeout=timeout) - - print('Loading the best model to evaluate') - model, _ = build_model(cfg) - weights = torch.load(cfg.MODEL.OMNIGRU_CHECKPOINT_URL) - model.load_state_dict(model.update_version(weights)) - - # cfg.OUTPUT.LOCATION = os.path.join(cfg.OUTPUT.LOCATION, "test") - extract_features(model, ts_data_loader, cfg) - -if __name__ == "__main__": - main()