From 25aa6ecb1ef271ab488421a5b5c1c1d723b253fd Mon Sep 17 00:00:00 2001
From: ly015
Date: Wed, 1 Feb 2023 16:48:37 +0800
Subject: [PATCH 1/3] ae inference align with master

---
 .../ae_hrnet-w32_8xb24-300e_coco-512x512.py   | 16 ++--
 demo/bottomup_demo.py                         | 69 +++++++++++++---
 mmpose/codecs/associative_embedding.py        | 78 ++++++++++++++-----
 .../transforms/bottomup_transforms.py         | 12 ++-
 mmpose/models/heads/heatmap_heads/ae_head.py  | 12 ++-
 5 files changed, 146 insertions(+), 41 deletions(-)

diff --git a/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py b/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py
index 5adc1aac1a..677b59b73a 100644
--- a/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py
+++ b/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py
@@ -96,7 +96,7 @@
         decoder=dict(codec, heatmap_size=codec['input_size'])),
     test_cfg=dict(
         multiscale_test=False,
-        flip_test=True,
+        flip_test=False,
         shift_heatmap=True,
         restore_heatmap_size=True,
         align_corners=False))
@@ -113,9 +113,14 @@
     dict(
         type='BottomupResize',
         input_size=codec['input_size'],
-        size_factor=32,
+        size_factor=64,
         resize_mode='expand'),
-    dict(type='PackPoseInputs')
+    dict(
+        type='PackPoseInputs',
+        meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
+                   'img_shape', 'input_size', 'input_center', 'input_scale',
+                   'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
+                   'skeleton_links'))
 ]

 # data loaders
@@ -142,7 +147,7 @@
         type=dataset_type,
         data_root=data_root,
         data_mode=data_mode,
-        ann_file='annotations/person_keypoints_val2017.json',
+        ann_file='annotations/person_keypoints_val2017_tiny_clean.json',
         data_prefix=dict(img='val2017/'),
         test_mode=True,
         pipeline=val_pipeline,
@@ -152,7 +157,8 @@
 # evaluators
 val_evaluator = dict(
     type='CocoMetric',
-    ann_file=data_root + 'annotations/person_keypoints_val2017.json',
+    ann_file=data_root +
+    'annotations/person_keypoints_val2017_tiny_clean.json',
     nms_mode='none',
     score_mode='keypoint',
 )
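A note for readers on the pipeline change above: `size_factor` is the multiple to which `BottomupResize` pads each side of the resized input, so raising it from 32 to 64 can change the padded input shape and therefore the decoded coordinates. A minimal sketch of that rounding-up, as I understand the option (illustrative helper, not code from this PR):

```python
import math

def pad_to_factor(size: int, factor: int) -> int:
    # Round ``size`` up to the nearest multiple of ``factor``.
    # Illustrative only; the actual logic lives in BottomupResize.
    return int(math.ceil(size / factor)) * factor

assert pad_to_factor(500, 32) == 512
assert pad_to_factor(500, 64) == 512
assert pad_to_factor(520, 64) == 576  # factor 64 pads more aggressively
```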
diff --git a/demo/bottomup_demo.py b/demo/bottomup_demo.py
index 3d6fee7a03..9eab6172d5 100644
--- a/demo/bottomup_demo.py
+++ b/demo/bottomup_demo.py
@@ -1,7 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+
 import mimetypes
 import os
 import time
+import os.path as osp
+import tempfile
 from argparse import ArgumentParser

 import cv2
@@ -208,19 +211,61 @@ def main():
             cap.release()

     else:
-        args.save_predictions = False
-        raise ValueError(
-            f'file {os.path.basename(args.input)} has invalid format.')
+        inputs = [osp.join(args.input, fn) for fn in os.listdir(args.input)]

-    if args.save_predictions:
-        with open(args.pred_save_path, 'w') as f:
-            json.dump(
-                dict(
-                    meta_info=model.dataset_meta,
-                    instance_info=pred_instances_list),
-                f,
-                indent='\t')
-        print(f'predictions have been saved at {args.pred_save_path}')
+        for fn in inputs:
+
+            input_type = mimetypes.guess_type(fn)[0].split('/')[0]
+            if input_type == 'image':
+                pred_instances = process_one_image(
+                    args, fn, model, visualizer, show_interval=0)
+                pred_instances_list = split_instances(pred_instances)
+
+            elif input_type == 'video':
+                tmp_folder = tempfile.TemporaryDirectory()
+                video = mmcv.VideoReader(fn)
+                progressbar = mmengine.ProgressBar(len(video))
+                video.cvt2frames(tmp_folder.name, show_progress=False)
+                output_root = args.output_root
+                args.output_root = tmp_folder.name
+                pred_instances_list = []
+
+                for frame_id, img_fname in enumerate(os.listdir(tmp_folder.name)):
+                    pred_instances = process_one_image(
+                        args,
+                        f'{tmp_folder.name}/{img_fname}',
+                        model,
+                        visualizer,
+                        show_interval=1)
+                    progressbar.update()
+                    pred_instances_list.append(
+                        dict(
+                            frame_id=frame_id,
+                            instances=split_instances(pred_instances)))
+
+                if output_root:
+                    mmcv.frames2video(
+                        tmp_folder.name,
+                        f'{output_root}/{os.path.basename(fn)}',
+                        fps=video.fps,
+                        fourcc='mp4v',
+                        show_progress=False)
+                tmp_folder.cleanup()
+
+            else:
+                args.save_predictions = False
+                raise ValueError(
+                    f'file {os.path.basename(fn)} has invalid format.')
+
+            if args.save_predictions:
+                with open(args.pred_save_path, 'w') as f:
+                    json.dump(
+                        dict(
+                            meta_info=model.dataset_meta,
+                            instance_info=pred_instances_list),
+                        f,
+                        indent='\t')
+                print(f'predictions have been saved at {args.pred_save_path}')


 if __name__ == '__main__':
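The directory branch added above dispatches each file on the major part of its guessed MIME type. For reference, `mimetypes.guess_type` returns a `(type, encoding)` pair, and the demo switches on the part before the slash (standalone illustration with hypothetical file names):

```python
import mimetypes

for name in ('photo.jpg', 'clip.mp4', 'notes.txt'):
    mime = mimetypes.guess_type(name)[0]           # e.g. 'image/jpeg'
    major = mime.split('/')[0] if mime else None   # 'image', 'video', ...
    print(name, '->', major)
```

Note that `guess_type` may return `None` for unknown extensions, so the `[0].split('/')` chain in the patch assumes recognizable inputs.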
diff --git a/mmpose/codecs/associative_embedding.py b/mmpose/codecs/associative_embedding.py
index 7e080f1657..0904249134 100644
--- a/mmpose/codecs/associative_embedding.py
+++ b/mmpose/codecs/associative_embedding.py
@@ -1,10 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from collections import namedtuple
+from copy import deepcopy
 from itertools import product
 from typing import Any, List, Optional, Tuple

 import numpy as np
 import torch
+from mmengine import dump
 from munkres import Munkres
 from torch import Tensor
@@ -75,7 +77,9 @@ def _init_group():
             tag_list=[])
         return _group

-    for i in keypoint_order:
+    group_history = []
+
+    for idx, i in enumerate(keypoint_order):
         # Get all valid candidates of the i-th keypoints
         valid = vals[i] > val_thr
         if not valid.any():
@@ -87,12 +91,22 @@ def _init_group():

         if len(groups) == 0:  # Initialize the group pool
             for tag, val, loc in zip(tags_i, vals_i, locs_i):
+
+                # Check if the keypoint belongs to existing groups
+                if len(groups):
+                    prev_tags = np.stack([g.tag_list[0] for g in groups])
+                    dists = np.linalg.norm(prev_tags - tag, ord=2, axis=1)
+                    if dists.min() < 1:
+                        continue
+
                 group = _init_group()
                 group.kpts[i] = loc
                 group.scores[i] = val
                 group.tag_list.append(tag)
                 groups.append(group)
+            costs_copy = None
+            matches = None

         else:  # Match keypoints to existing groups
             groups = groups[:max_groups]
@@ -101,17 +115,18 @@
             # Calculate distance matrix between group tags and tag candidates
             # of the i-th keypoint
             # Shape: (M', 1, L) , (1, G, L) -> (M', G, L)
-            diff = tags_i[:, None] - np.array(group_tags)[None]
+            diff = (tags_i[:, None] -
+                    np.array(group_tags)[None]).astype(np.float64)
             dists = np.linalg.norm(diff, ord=2, axis=2)
             num_kpts, num_groups = dists.shape[:2]

             # Experimental cost function for keypoint-group matching
             costs = np.round(dists) * 100 - vals_i[..., None]
+
             if num_kpts > num_groups:
-                padding = np.full((num_kpts, num_kpts - num_groups),
-                                  1e10,
-                                  dtype=np.float32)
+                padding = np.full((num_kpts, num_kpts - num_groups), 1e10)
                 costs = np.concatenate((costs, padding), axis=1)
+            costs_copy = costs.copy()

             # Match keypoints and groups by Munkres algorithm
             matches = munkres.compute(costs)
@@ -121,13 +136,30 @@
                     # Add the keypoint to the matched group
                     group = groups[group_idx]
                 else:
-                    # Initialize a new group with unmatched keypoint
-                    group = _init_group()
-                    groups.append(group)
-
-                group.kpts[i] = locs_i[kpt_idx]
-                group.scores[i] = vals_i[kpt_idx]
-                group.tag_list.append(tags_i[kpt_idx])
+                    # if dists[kpt_idx].min() < 0.2:
+                    if False:
+                        group = None
+                    else:
+                        # Initialize a new group with unmatched keypoint
+                        group = _init_group()
+                        groups.append(group)
+                if group is not None:
+                    group.kpts[i] = locs_i[kpt_idx]
+                    group.scores[i] = vals_i[kpt_idx]
+                    group.tag_list.append(tags_i[kpt_idx])
+
+        out = {
+            'idx': idx,
+            'i': i,
+            'costs': costs_copy,
+            'matches': matches,
+            'kpts': np.array([g.kpts for g in groups]),
+            'scores': np.array([g.scores for g in groups]),
+            'tag_list': [np.array(g.tag_list) for g in groups],
+        }
+        group_history.append(deepcopy(out))
+
+    dump(group_history, 'group_history.pkl')

     groups = groups[:max_groups]
     if groups:
@@ -210,7 +242,7 @@
         decode_gaussian_kernel: int = 3,
         decode_keypoint_thr: float = 0.1,
         decode_tag_thr: float = 1.0,
-        decode_topk: int = 20,
+        decode_topk: int = 30,
         decode_max_instances: Optional[int] = None,
     ) -> None:
         super().__init__()
@@ -336,6 +368,12 @@ def _get_batch_topk(self, batch_heatmaps: Tensor, batch_tags: Tensor,
         B, K, H, W = batch_heatmaps.shape
         L = batch_tags.shape[1] // K

+        # Heatmap NMS
+        dump(batch_heatmaps.cpu().numpy(), 'heatmaps.pkl')
+        batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
+                                           self.decode_nms_kernel)
+        dump(batch_heatmaps.cpu().numpy(),
+             'heatmaps_nms.pkl')
+
         # shape of topk_val, top_indices: (B, K, TopK)
         topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk(
             k, dim=-1)
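The `_get_batch_topk` hunk above moves heatmap NMS (plus two debugging `dump` calls) ahead of the per-keypoint top-k selection. As a rough standalone sketch of what suppress-then-select computes (an approximation for illustration, not the mmpose implementation of `batch_heatmap_nms`):

```python
import torch
import torch.nn.functional as F

def nms_then_topk(heatmaps: torch.Tensor, kernel: int = 5, k: int = 30):
    """heatmaps: (B, K, H, W); returns per-channel (values, x, y), each of
    shape (B, K, k). Sketch only -- kernel/k values here are illustrative."""
    pad = (kernel - 1) // 2
    # A pixel survives only if it equals the maximum of its neighborhood.
    maxima = F.max_pool2d(heatmaps, kernel, stride=1, padding=pad)
    heatmaps = heatmaps * (heatmaps == maxima)
    _, _, _, W = heatmaps.shape
    vals, flat_idx = heatmaps.flatten(-2, -1).topk(k, dim=-1)
    return vals, flat_idx % W, flat_idx // W
```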
@@ -433,9 +471,8 @@ def _fill_missing_keypoints(self, keypoints: np.ndarray,
                 cost_map = np.round(dist_map) * 100 - heatmaps[k]  # H, W
                 y, x = np.unravel_index(np.argmin(cost_map), shape=(H, W))
                 keypoints[n, k] = [x, y]
-                keypoint_scores[n, k] = heatmaps[k, y, x]

-        return keypoints, keypoint_scores
+        return keypoints

     def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
                      ) -> Tuple[List[np.ndarray], List[np.ndarray]]:
@@ -457,15 +494,12 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
             batch, each is in shape (N, K). It usually represents the
             confidence of the keypoint prediction
         """
+        B, _, H, W = batch_heatmaps.shape
         assert batch_tags.shape[0] == B and batch_tags.shape[2:4] == (H, W), (
             f'Mismatched shapes of heatmap ({batch_heatmaps.shape}) and '
             f'tagging map ({batch_tags.shape})')

-        # Heatmap NMS
-        batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
-                                           self.decode_nms_kernel)
-
         # Get top-k in each heatmap and convert to numpy
         batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy(
             self._get_batch_topk(
@@ -489,7 +523,7 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor

             if keypoints.size > 0:
                 # identify missing keypoints
-                keypoints, scores = self._fill_missing_keypoints(
+                keypoints = self._fill_missing_keypoints(
                     keypoints, scores, heatmaps, tags)

                 # refine keypoint coordinates according to heatmap distribution
@@ -500,6 +534,8 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
                         blur_kernel_size=self.decode_gaussian_kernel)
                 else:
                     keypoints = refine_keypoints(keypoints, heatmaps)
+                # keypoints += 0.75
+                keypoints += 0.5

             batch_keypoints.append(keypoints)
             batch_keypoint_scores.append(scores)
diff --git a/mmpose/datasets/transforms/bottomup_transforms.py b/mmpose/datasets/transforms/bottomup_transforms.py
index c31e0ae17d..d6b67bdf28 100644
--- a/mmpose/datasets/transforms/bottomup_transforms.py
+++ b/mmpose/datasets/transforms/bottomup_transforms.py
@@ -484,6 +484,7 @@ def transform(self, results: Dict) -> Optional[dict]:
                     output_size=actual_input_size)
             else:
                 center = np.array([img_w / 2, img_h / 2], dtype=np.float32)
+                center = np.round(center)
                 scale = np.array([
                     img_w * padded_input_size[0] / actual_input_size[0],
                     img_h * padded_input_size[1] / actual_input_size[1]
@@ -495,11 +496,18 @@ def transform(self, results: Dict) -> Optional[dict]:
                 rot=0,
                 output_size=padded_input_size)

-            _img = cv2.warpAffine(
-                img, warp_mat, padded_input_size, flags=cv2.INTER_LINEAR)
+            _img = cv2.warpAffine(img, warp_mat, padded_input_size)

             imgs.append(_img)

+            # print('#' * 20)
+            # print('w,h: ', img_w, img_h, 'center: ', center, 'scale: ',
+            #       scale,
+            #       'actual_input_size: ', actual_input_size,
+            #       'padded_input_size: ', padded_input_size)
+            # print(warp_mat)
+            # print('#' * 20)
+
             # Store the transform information w.r.t. the main input size
             if i == 0:
                 results['img_shape'] = padded_input_size[::-1]
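The commented-out prints above dump the `center`/`scale` box and the warp matrix that map the source image onto the padded model input; the newly added `center = np.round(center)` is a temporary alignment hack that patch 2 below removes again. A simplified sketch of how a center-scale pair becomes an affine warp (hypothetical helper, not mmpose's `get_warp_matrix`):

```python
import cv2
import numpy as np

def warp_by_center_scale(img, center, scale, output_size):
    """center: (x, y); scale: (w, h) of the source box in pixels;
    output_size: (W, H) of the model input. Sketch without rotation."""
    W, H = output_size
    src_w, src_h = scale
    # Three corresponding points fully determine an affine transform.
    src = np.array([
        [center[0] - src_w / 2, center[1] - src_h / 2],
        [center[0] + src_w / 2, center[1] - src_h / 2],
        [center[0] - src_w / 2, center[1] + src_h / 2],
    ], dtype=np.float32)
    dst = np.array([[0, 0], [W, 0], [0, H]], dtype=np.float32)
    warp_mat = cv2.getAffineTransform(src, dst)
    return cv2.warpAffine(img, warp_mat, (W, H))
```

Because `center` enters this mapping directly, rounding it shifts every output pixel by a fraction of a pixel, which is why the rounding matters in these alignment experiments.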
diff --git a/mmpose/models/heads/heatmap_heads/ae_head.py b/mmpose/models/heads/heatmap_heads/ae_head.py
index bd12d57a33..451df0bbab 100644
--- a/mmpose/models/heads/heatmap_heads/ae_head.py
+++ b/mmpose/models/heads/heatmap_heads/ae_head.py
@@ -2,6 +2,7 @@
 from typing import List, Optional, Sequence, Tuple, Union

 import torch
+import torch.nn.functional as F
 from mmengine.structures import PixelData
 from mmengine.utils import is_list_of
 from torch import Tensor
@@ -110,7 +111,7 @@ def predict(self,
             # TTA: multi-scale test
             assert is_list_of(feats, list if flip_test else tuple)
         else:
-            assert is_list_of(feats, tuple if flip_test else Tensor)
+            assert isinstance(feats, list if flip_test else tuple)
             feats = [feats]

         # resize heatmaps to align with the input size
@@ -129,6 +130,15 @@ def predict(self,
         for scale_idx, _feats in enumerate(feats):
             if not flip_test:
                 _heatmaps, _tags = self.forward(_feats)
+                if heatmap_size:
+                    _heatmaps = F.interpolate(
+                        _heatmaps, (img_h, img_w),
+                        mode='bilinear',
+                        align_corners=align_corners)
+                    _tags = F.interpolate(
+                        _tags, (img_h, img_w),
+                        mode='bilinear',
+                        align_corners=align_corners)

             else:
                 # TTA: flip test
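The `ae_head.py` hunk above resizes heatmaps and tag maps back to the input resolution with `F.interpolate`, where `align_corners` decides how output pixel centers sample the input grid, another source of half-pixel discrepancies of the kind this PR is chasing. A small self-contained demonstration (toy values, not from the PR):

```python
import torch
import torch.nn.functional as F

x = torch.arange(4, dtype=torch.float32).reshape(1, 1, 1, 4)  # [[0, 1, 2, 3]]

# align_corners=False treats pixels as unit areas; border outputs sit
# half an input pixel inside the edges.
print(F.interpolate(x, (1, 8), mode='bilinear',
                    align_corners=False).flatten().tolist())
# align_corners=True pins the corner samples, so 0.0 and 3.0 survive.
print(F.interpolate(x, (1, 8), mode='bilinear',
                    align_corners=True).flatten().tolist())
```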
From e43ccce92c1ac32750685c7da3a2a21c74a15123 Mon Sep 17 00:00:00 2001
From: ly015
Date: Tue, 7 Feb 2023 18:44:06 +0800
Subject: [PATCH 2/3] remove center rounding in bottom-up affine

---
 .../ae_hrnet-w32_8xb24-300e_coco-512x512.py   |  5 +--
 mmpose/codecs/associative_embedding.py        | 44 +++++++++++--------
 .../transforms/bottomup_transforms.py         |  2 +-
 3 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py b/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py
index 677b59b73a..988db0dc7b 100644
--- a/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py
+++ b/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py
@@ -147,7 +147,7 @@
         type=dataset_type,
         data_root=data_root,
         data_mode=data_mode,
-        ann_file='annotations/person_keypoints_val2017_tiny_clean.json',
+        ann_file='annotations/person_keypoints_val2017.json',
         data_prefix=dict(img='val2017/'),
         test_mode=True,
         pipeline=val_pipeline,
@@ -157,8 +157,7 @@
 # evaluators
 val_evaluator = dict(
     type='CocoMetric',
-    ann_file=data_root +
-    'annotations/person_keypoints_val2017_tiny_clean.json',
+    ann_file=data_root + 'annotations/person_keypoints_val2017.json',
     nms_mode='none',
     score_mode='keypoint',
 )
diff --git a/mmpose/codecs/associative_embedding.py b/mmpose/codecs/associative_embedding.py
index 0904249134..9c9a1f0a6e 100644
--- a/mmpose/codecs/associative_embedding.py
+++ b/mmpose/codecs/associative_embedding.py
@@ -1,12 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from collections import namedtuple
-from copy import deepcopy
+# from copy import deepcopy
 from itertools import product
 from typing import Any, List, Optional, Tuple

 import numpy as np
 import torch
-from mmengine import dump
+# from mmengine import dump
 from munkres import Munkres
 from torch import Tensor
@@ -77,7 +77,7 @@ def _init_group():
             tag_list=[])
         return _group

-    group_history = []
+    # group_history = []

     for idx, i in enumerate(keypoint_order):
         # Get all valid candidates of the i-th keypoints
@@ -105,7 +105,7 @@ def _init_group():
                 group.tag_list.append(tag)
                 groups.append(group)
-            costs_copy = None
+            # costs_copy = None
             matches = None

         else:  # Match keypoints to existing groups
@@ -126,7 +126,7 @@ def _init_group():
             if num_kpts > num_groups:
                 padding = np.full((num_kpts, num_kpts - num_groups), 1e10)
                 costs = np.concatenate((costs, padding), axis=1)
-            costs_copy = costs.copy()
+            # costs_copy = costs.copy()

             # Match keypoints and groups by Munkres algorithm
             matches = munkres.compute(costs)
@@ -148,18 +148,18 @@ def _init_group():
                     group.scores[i] = vals_i[kpt_idx]
                     group.tag_list.append(tags_i[kpt_idx])

-        out = {
-            'idx': idx,
-            'i': i,
-            'costs': costs_copy,
-            'matches': matches,
-            'kpts': np.array([g.kpts for g in groups]),
-            'scores': np.array([g.scores for g in groups]),
-            'tag_list': [np.array(g.tag_list) for g in groups],
-        }
-        group_history.append(deepcopy(out))
+        # out = {
+        #     'idx': idx,
+        #     'i': i,
+        #     'costs': costs_copy,
+        #     'matches': matches,
+        #     'kpts': np.array([g.kpts for g in groups]),
+        #     'scores': np.array([g.scores for g in groups]),
+        #     'tag_list': [np.array(g.tag_list) for g in groups],
+        # }
+        # group_history.append(deepcopy(out))

-    dump(group_history, 'group_history.pkl')
+    # dump(group_history, 'group_history.pkl')

     groups = groups[:max_groups]
     if groups:
@@ -369,10 +369,10 @@
         L = batch_tags.shape[1] // K

         # Heatmap NMS
-        dump(batch_heatmaps.cpu().numpy(), 'heatmaps.pkl')
+        # dump(batch_heatmaps.cpu().numpy(), 'heatmaps.pkl')
         batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
                                            self.decode_nms_kernel)
-        dump(batch_heatmaps.cpu().numpy(), 'heatmaps_nms.pkl')
+        # dump(batch_heatmaps.cpu().numpy(), 'heatmaps_nms.pkl')

         # shape of topk_val, top_indices: (B, K, TopK)
         topk_vals, topk_indices = batch_heatmaps.flatten(-2, -1).topk(
             k, dim=-1)
@@ -534,7 +534,13 @@
                 blur_kernel_size=self.decode_gaussian_kernel)
             else:
                 keypoints = refine_keypoints(keypoints, heatmaps)
-            # keypoints += 0.75
+            # The following 0.5-pixel shift is adapted from mmpose 0.x,
+            # where the heatmap center is calculated by a biased
+            # rounding ``mu=[int(x), int(y)]``. We keep this shift
+            # operation for now to be compatible with 0.x checkpoints.
+            # In mmpose 1.x, the AE heatmap center is calculated by the
+            # unbiased rounding ``mu=[int(x+0.5), int(y+0.5)]``, so the
+            # following shift will be removed in the future.
             keypoints += 0.5

             batch_keypoints.append(keypoints)
diff --git a/mmpose/datasets/transforms/bottomup_transforms.py b/mmpose/datasets/transforms/bottomup_transforms.py
index d6b67bdf28..1355d3359a 100644
--- a/mmpose/datasets/transforms/bottomup_transforms.py
+++ b/mmpose/datasets/transforms/bottomup_transforms.py
@@ -484,7 +484,7 @@ def transform(self, results: Dict) -> Optional[dict]:
                     output_size=actual_input_size)
             else:
                 center = np.array([img_w / 2, img_h / 2], dtype=np.float32)
-                center = np.round(center)
+                # center = np.round(center)
                 scale = np.array([
                     img_w * padded_input_size[0] / actual_input_size[0],
                     img_h * padded_input_size[1] / actual_input_size[1]
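The comment introduced above is the heart of the alignment issue this series chases: encoding a continuous coordinate into a discrete heatmap bin with `int(x)` truncates (a biased estimate), while `int(x + 0.5)` rounds to nearest, and the decoder must add back whichever offset the encoder implied. A toy illustration of the two conventions (assumed values, for illustration only):

```python
# Biased vs. unbiased heatmap center encoding (toy example).
x = 3.7  # continuous ground-truth coordinate, in heatmap pixels

mu_biased = int(x)          # 3 -- mmpose 0.x truncation
mu_unbiased = int(x + 0.5)  # 4 -- mmpose 1.x round-to-nearest

# A decoder that reads the peak back at ``mu`` recovers:
decoded_biased = mu_biased + 0.5       # 3.5 -- needs the +0.5 shift kept above
decoded_unbiased = float(mu_unbiased)  # 4.0 -- no systematic offset to undo

print(decoded_biased, decoded_unbiased)
```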
From ede84fb74ad82946c6900f94d1caf05fdba1e093 Mon Sep 17 00:00:00 2001
From: lupeng
Date: Sat, 10 Jun 2023 02:15:57 +0800
Subject: [PATCH 3/3] fix bottomup demo

---
 demo/bottomup_demo.py | 69 ++++++++-----------------------------------
 1 file changed, 12 insertions(+), 57 deletions(-)

diff --git a/demo/bottomup_demo.py b/demo/bottomup_demo.py
index 9eab6172d5..3d6fee7a03 100644
--- a/demo/bottomup_demo.py
+++ b/demo/bottomup_demo.py
@@ -1,10 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-
 import mimetypes
 import os
 import time
-import os.path as osp
-import tempfile
 from argparse import ArgumentParser

 import cv2
@@ -211,61 +208,19 @@ def main():
             cap.release()

     else:
-        inputs = [osp.join(args.input, fn) for fn in os.listdir(args.input)]
-
-        for fn in inputs:
-
-            input_type = mimetypes.guess_type(fn)[0].split('/')[0]
-            if input_type == 'image':
-                pred_instances = process_one_image(
-                    args, fn, model, visualizer, show_interval=0)
-                pred_instances_list = split_instances(pred_instances)
-
-            elif input_type == 'video':
-                tmp_folder = tempfile.TemporaryDirectory()
-                video = mmcv.VideoReader(fn)
-                progressbar = mmengine.ProgressBar(len(video))
-                video.cvt2frames(tmp_folder.name, show_progress=False)
-                output_root = args.output_root
-                args.output_root = tmp_folder.name
-                pred_instances_list = []
-
-                for frame_id, img_fname in enumerate(os.listdir(tmp_folder.name)):
-                    pred_instances = process_one_image(
-                        args,
-                        f'{tmp_folder.name}/{img_fname}',
-                        model,
-                        visualizer,
-                        show_interval=1)
-                    progressbar.update()
-                    pred_instances_list.append(
-                        dict(
-                            frame_id=frame_id,
-                            instances=split_instances(pred_instances)))
-
-                if output_root:
-                    mmcv.frames2video(
-                        tmp_folder.name,
-                        f'{output_root}/{os.path.basename(fn)}',
-                        fps=video.fps,
-                        fourcc='mp4v',
-                        show_progress=False)
-                tmp_folder.cleanup()
-
-            else:
-                args.save_predictions = False
-                raise ValueError(
-                    f'file {os.path.basename(fn)} has invalid format.')
+        args.save_predictions = False
+        raise ValueError(
+            f'file {os.path.basename(args.input)} has invalid format.')

-            if args.save_predictions:
-                with open(args.pred_save_path, 'w') as f:
-                    json.dump(
-                        dict(
-                            meta_info=model.dataset_meta,
-                            instance_info=pred_instances_list),
-                        f,
-                        indent='\t')
-                print(f'predictions have been saved at {args.pred_save_path}')
+    if args.save_predictions:
+        with open(args.pred_save_path, 'w') as f:
+            json.dump(
+                dict(
+                    meta_info=model.dataset_meta,
+                    instance_info=pred_instances_list),
+                f,
+                indent='\t')
+        print(f'predictions have been saved at {args.pred_save_path}')


 if __name__ == '__main__':
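For reference, the grouping step that patch 1 instruments and patch 2 cleans up assigns each candidate of a keypoint type to an existing group by associative-embedding tag distance, solving an assignment problem in which rounded tag distance dominates and heatmap confidence only breaks ties. A condensed standalone sketch of that matching, simplified from the hunks above (the padding value and rounding follow the patch; everything else is illustrative):

```python
import numpy as np
from munkres import Munkres

def match_candidates_to_groups(cand_tags, cand_vals, group_tags):
    """cand_tags: (M, L) tags of the M candidates of one keypoint type;
    cand_vals: (M,) heatmap scores; group_tags: (G, L) group tags.
    Returns (kpt_idx, col) pairs; col >= G means 'start a new group'."""
    dists = np.linalg.norm(
        cand_tags[:, None] - group_tags[None], ord=2, axis=2)  # (M, G)
    # Rounded tag distance dominates; confidence only breaks ties.
    costs = np.round(dists) * 100 - cand_vals[:, None]
    M, G = costs.shape
    if M > G:  # pad so every candidate receives an assignment
        costs = np.concatenate([costs, np.full((M, M - G), 1e10)], axis=1)
    return Munkres().compute(costs.tolist())
```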