diff --git a/README.md b/README.md index 951c4adf2e..c40b9cdc4c 100644 --- a/README.md +++ b/README.md @@ -342,21 +342,20 @@ This project is released under the [Apache 2.0 license](LICENSE). - [MMEngine](https://github.com/open-mmlab/mmengine): OpenMMLab foundational library for training deep learning models. - [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab foundational library for computer vision. -- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages. -- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab image classification toolbox and benchmark. +- [MMPreTrain](https://github.com/open-mmlab/mmpretrain): OpenMMLab pre-training toolbox and benchmark. +- [MMagic](https://github.com/open-mmlab/mmagic): Open**MM**Lab **A**dvanced, **G**enerative and **I**ntelligent **C**reation toolbox. - [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark. - [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection. - [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark. +- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark. - [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark. - [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox. - [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark. - [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 3D human parametric model toolbox and benchmark. -- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab self-supervised learning toolbox and benchmark. -- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark. - [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab fewshot learning toolbox and benchmark. - [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab's next-generation action understanding toolbox and benchmark. -- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab video perception toolbox and benchmark. - [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab optical flow toolbox and benchmark. -- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab image and video editing toolbox. -- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab image and video generative models toolbox. - [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab Model Deployment Framework. +- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab model compression toolbox and benchmark. +- [MIM](https://github.com/open-mmlab/mim): MIM installs OpenMMLab packages. +- [Playground](https://github.com/open-mmlab/playground): A central hub for gathering and showcasing amazing projects built upon OpenMMLab. 
diff --git a/README_CN.md b/README_CN.md index 49a956cab9..519e9889da 100644 --- a/README_CN.md +++ b/README_CN.md @@ -339,24 +339,23 @@ MMPose 是一款由不同学校和公司共同贡献的开源项目。我们感 - [MMEngine](https://github.com/open-mmlab/mmengine): OpenMMLab 深度学习模型训练基础库 - [MMCV](https://github.com/open-mmlab/mmcv): OpenMMLab 计算机视觉基础库 -- [MIM](https://github.com/open-mmlab/mim): OpenMMlab 项目、算法、模型的统一入口 -- [MMClassification](https://github.com/open-mmlab/mmclassification): OpenMMLab 图像分类工具箱 +- [MMPreTrain](https://github.com/open-mmlab/mmpretrain): OpenMMLab 深度学习预训练工具箱 +- [MMagic](https://github.com/open-mmlab/mmagic): OpenMMLab 新一代人工智能内容生成(AIGC)工具箱 - [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab 目标检测工具箱 - [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab 新一代通用 3D 目标检测平台 - [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab 旋转框检测工具箱与测试基准 +- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab 一体化视频目标感知平台 - [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab 语义分割工具箱 - [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab 全流程文字检测识别理解工具包 - [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab 姿态估计工具箱 - [MMHuman3D](https://github.com/open-mmlab/mmhuman3d): OpenMMLab 人体参数化模型工具箱与测试基准 -- [MMSelfSup](https://github.com/open-mmlab/mmselfsup): OpenMMLab 自监督学习工具箱与测试基准 -- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab 模型压缩工具箱与测试基准 - [MMFewShot](https://github.com/open-mmlab/mmfewshot): OpenMMLab 少样本学习工具箱与测试基准 - [MMAction2](https://github.com/open-mmlab/mmaction2): OpenMMLab 新一代视频理解工具箱 -- [MMTracking](https://github.com/open-mmlab/mmtracking): OpenMMLab 一体化视频目标感知平台 - [MMFlow](https://github.com/open-mmlab/mmflow): OpenMMLab 光流估计工具箱与测试基准 -- [MMEditing](https://github.com/open-mmlab/mmediting): OpenMMLab 图像视频编辑工具箱 -- [MMGeneration](https://github.com/open-mmlab/mmgeneration): OpenMMLab 图片视频生成模型工具箱 - [MMDeploy](https://github.com/open-mmlab/mmdeploy): OpenMMLab 模型部署框架 +- [MMRazor](https://github.com/open-mmlab/mmrazor): OpenMMLab 模型压缩工具箱与测试基准 +- [MIM](https://github.com/open-mmlab/mim): OpenMMLab 项目、算法、模型的统一入口 +- [Playground](https://github.com/open-mmlab/playground): 收集和展示 OpenMMLab 相关的前沿、有趣的社区项目 

 ## 欢迎加入 OpenMMLab 社区
diff --git a/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-animalpose.yml b/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-animalpose.yml
new file mode 100644
index 0000000000..b7ef97c395
--- /dev/null
+++ b/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-animalpose.yml
@@ -0,0 +1,89 @@
+Models:
+- Config: configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-l_8xb64-210e_animalpose-256x256.py
+  In Collection: RTMPose
+  Metadata:
+    Architecture:
+    - RTMPose
+    Training Data: Animal-Pose
+  Name: rtmpose-l_8xb64-210e_animalpose-256x256
+  Results:
+  - Dataset: Animal-Pose
+    Metrics:
+      AP: 0.766
+      AP@0.5: 0.959
+      AP@0.75: 0.855
+      AP (M): 0.725
+      AP (L): 0.778
+      AR: 0.800
+      AR@0.5: 0.968
+      AR@0.75: 0.874
+      AR (M): 0.769
+      AR (L): 0.808
+    Task: Animal 2D Keypoint
+# Weights: https://download.openmmlab.com/mmpose/animal/hrnet/hrnet_w32_animalpose_256x256-1aa7f075_20210426.pth
+- Config: configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-m_8xb64-210e_animalpose-256x256.py
+  In Collection: RTMPose
+  Metadata:
+    Architecture:
+    - RTMPose
+    Training Data: Animal-Pose
+  Name: rtmpose-m_8xb64-210e_animalpose-256x256
+  Results:
+  - Dataset: Animal-Pose
+    Metrics:
+      AP: 0.598
+      AP@0.5: 0.896
+      AP@0.75: 0.653
+      AP (M): 0.596
+      AP (L): 0.603
+      AR: 0.642
+      AR@0.5: 0.900
+      AR@0.75: 0.699
+      AR (M): 0.660
+      AR (L): 0.641
+    Task: Animal 2D Keypoint
+# Weights: https://download.openmmlab.com/mmpose/animal/hrnet/hrnet_w32_animalpose_256x256-1aa7f075_20210426.pth
+- Config: configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-s_8xb64-210e_animalpose-256x256.py
+  In Collection: RTMPose
+  Metadata:
+    Architecture:
+    - RTMPose
+    Training Data: Animal-Pose
+  Name: rtmpose-s_8xb64-210e_animalpose-256x256
+  Results:
+  - Dataset: Animal-Pose
+    Metrics:
+      AP: 0.709
+      AP@0.5: 0.938
+      AP@0.75: 0.799
+      AP (M): 0.674
+      AP (L): 0.718
+      AR: 0.748
+      AR@0.5: 0.946
+      AR@0.75: 0.824
+      AR (M): 0.730
+      AR (L): 0.754
+    Task: Animal 2D Keypoint
+# Weights: https://download.openmmlab.com/mmpose/animal/hrnet/hrnet_w32_animalpose_256x256-1aa7f075_20210426.pth
+- Config: configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-t_8xb64-210e_animalpose-256x256.py
+  In Collection: RTMPose
+  Metadata:
+    Architecture:
+    - RTMPose
+    Training Data: Animal-Pose
+  Name: rtmpose-t_8xb64-210e_animalpose-256x256
+  Results:
+  - Dataset: Animal-Pose
+    Metrics:
+      AP: 0.680
+      AP@0.5: 0.927
+      AP@0.75: 0.770
+      AP (M): 0.657
+      AP (L): 0.688
+      AR: 0.718
+      AR@0.5: 0.934
+      AR@0.75: 0.792
+      AR (M): 0.712
+      AR (L): 0.721
+    Task: Animal 2D Keypoint
+# Weights: https://download.openmmlab.com/mmpose/animal/hrnet/hrnet_w32_animalpose_256x256-1aa7f075_20210426.pth
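The metafile above follows the OpenMMLab model-index schema: one entry per variant, each pointing at its config and carrying its Animal-Pose validation metrics (the `# Weights:` lines are YAML comments, apparently placeholders until RTMPose checkpoints are uploaded). A minimal sketch of consuming it with PyYAML, assuming only that the file lands at the path introduced by this PR:

```python
# Minimal sketch: reading the metafile above with PyYAML to list each
# variant and its headline numbers. The commented '# Weights:' lines are
# ignored by the YAML parser.
import yaml

with open('configs/animal_2d_keypoint/rtmpose/animalpose/'
          'rtmpose-animalpose.yml') as f:
    meta = yaml.safe_load(f)

for model in meta['Models']:
    metrics = model['Results'][0]['Metrics']
    print(f"{model['Name']}: AP={metrics['AP']:.3f} AR={metrics['AR']:.3f}")
```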
diff --git a/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-l_8xb64-210e_animalpose-256x256.py b/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-l_8xb64-210e_animalpose-256x256.py
new file mode 100644
index 0000000000..5d4b017b5b
--- /dev/null
+++ b/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-l_8xb64-210e_animalpose-256x256.py
@@ -0,0 +1,172 @@
+_base_ = ['../../../_base_/default_runtime.py']
+
+# runtime
+max_epochs = 210
+base_lr = 5e-4
+num_workers = 8
+
+train_cfg = dict(max_epochs=max_epochs, val_interval=10)
+randomness = dict(seed=23)
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(
+    type='Adam',
+    lr=base_lr,
+))
+
+# learning rate
+batch_size = 64
+param_scheduler = [
+    dict(
+        type='LinearLR', begin=0, end=500, start_factor=0.001,
+        by_epoch=False),  # warm-up
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=max_epochs,
+        milestones=[170, 200],
+        gamma=0.1,
+        by_epoch=True)
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=batch_size)
+
+# hooks
+default_hooks = dict(
+    checkpoint=dict(save_best='animalpose/AP', rule='greater'))
+
+# codec settings
+codec = dict(
+    type='SimCCLabel',
+    input_size=(256, 256),
+    sigma=(5.66, 5.66),
+    simcc_split_ratio=2.0,
+    normalize=False,
+    use_dark=False)
+
+# model settings
+model = dict(
+    type='TopdownPoseEstimator',
+    data_preprocessor=dict(
+        type='PoseDataPreprocessor',
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        bgr_to_rgb=True),
+    backbone=dict(
+        _scope_='mmdet',
+        type='CSPNeXt',
+        arch='P5',
+        expand_ratio=0.5,
+        deepen_factor=1.,
+        widen_factor=1.,
+        out_indices=(4, ),
+        channel_attention=True,
+        norm_cfg=dict(type='SyncBN'),
+        act_cfg=dict(type='SiLU'),
+        init_cfg=dict(
+            type='Pretrained',
+            prefix='backbone.',
+            checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
+            'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth'
+        )),
+    head=dict(
+        type='RTMCCHead',
+        in_channels=1024,
+        out_channels=20,
+        input_size=codec['input_size'],
+        in_featuremap_size=(8, 8),
+        simcc_split_ratio=codec['simcc_split_ratio'],
+        final_layer_kernel_size=7,
+        gau_cfg=dict(
+            hidden_dims=256,
+            s=128,
+            expansion_factor=2,
+            dropout_rate=0.,
+            drop_path=0.,
+            act_fn='SiLU',
+            use_rel_bias=False,
+            pos_enc=False),
+        loss=dict(
+            type='KLDiscretLoss',
+            use_target_weight=True,
+            beta=10.,
+            label_softmax=True),
+        decoder=codec),
+    test_cfg=dict(flip_test=True))
+
+# base dataset settings
+dataset_type = 'AnimalPoseDataset'
+data_mode = 'topdown'
+data_root = 'data/animalpose/'
+
+backend_args = dict(backend='local')
+
+# pipelines
+train_pipeline = [
+    dict(type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='RandomFlip', direction='horizontal'),
+    dict(type='RandomHalfBody'),
+    dict(type='RandomBBoxTransform'),
+    dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(type='GenerateTarget', encoder=codec),
+    dict(type='PackPoseInputs')
+]
+val_pipeline = [
+    dict(type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(type='PackPoseInputs')
+]
+
+# data loaders
+train_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=num_workers,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/animalpose_train.json',
+        data_prefix=dict(img=''),
+        pipeline=train_pipeline,
+    ))
+val_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=num_workers,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/animalpose_val.json',
+        data_prefix=dict(img=''),
+        test_mode=True,
+        pipeline=val_pipeline,
+    ))
+test_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=num_workers,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/animalpose_test.json',
+        data_prefix=dict(img=''),
+        test_mode=True,
+        pipeline=val_pipeline,
+    ))
+
+# evaluators
+val_evaluator = dict(
+    type='CocoMetric', ann_file=data_root + 'annotations/animalpose_val.json')
+test_evaluator = dict(
+    type='CocoMetric', ann_file=data_root + 'annotations/animalpose_test.json')
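The four configs in this PR differ only in backbone depth/width, head input channels, and the pretrained CSPNeXt checkpoint. A quick way to sanity-check one without launching training is to load it through MMEngine; a sketch, assuming an mmpose 1.x environment (the override values are illustrative only):

```python
# Minimal sketch: loading the rtmpose-l config above with MMEngine to
# inspect or tweak it from Python. Config.fromfile resolves _base_.
from mmengine.config import Config

cfg = Config.fromfile(
    'configs/animal_2d_keypoint/rtmpose/animalpose/'
    'rtmpose-l_8xb64-210e_animalpose-256x256.py')

print(cfg.model.head.in_channels)  # 1024 for the CSPNeXt-l backbone
print(cfg.codec['input_size'])     # (256, 256)

# Hypothetical overrides for a small local dry run:
cfg.train_dataloader.batch_size = 32
cfg.train_cfg.val_interval = 5
```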
diff --git a/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-m_8xb64-210e_animalpose-256x256.py b/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-m_8xb64-210e_animalpose-256x256.py
new file mode 100644
index 0000000000..b0d0fd3268
--- /dev/null
+++ b/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-m_8xb64-210e_animalpose-256x256.py
@@ -0,0 +1,172 @@
+_base_ = ['../../../_base_/default_runtime.py']
+
+# runtime
+max_epochs = 210
+base_lr = 5e-4
+num_workers = 8
+
+train_cfg = dict(max_epochs=max_epochs, val_interval=10)
+randomness = dict(seed=23)
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(
+    type='Adam',
+    lr=base_lr,
+))
+
+# learning rate
+batch_size = 64
+param_scheduler = [
+    dict(
+        type='LinearLR', begin=0, end=500, start_factor=0.001,
+        by_epoch=False),  # warm-up
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=max_epochs,
+        milestones=[170, 200],
+        gamma=0.1,
+        by_epoch=True)
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=batch_size)
+
+# hooks
+default_hooks = dict(
+    checkpoint=dict(save_best='animalpose/AP', rule='greater'))
+
+# codec settings
+codec = dict(
+    type='SimCCLabel',
+    input_size=(256, 256),
+    sigma=(5.66, 5.66),
+    simcc_split_ratio=2.0,
+    normalize=False,
+    use_dark=False)
+
+# model settings
+model = dict(
+    type='TopdownPoseEstimator',
+    data_preprocessor=dict(
+        type='PoseDataPreprocessor',
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        bgr_to_rgb=True),
+    backbone=dict(
+        _scope_='mmdet',
+        type='CSPNeXt',
+        arch='P5',
+        expand_ratio=0.5,
+        deepen_factor=0.67,
+        widen_factor=0.75,
+        out_indices=(4, ),
+        channel_attention=True,
+        norm_cfg=dict(type='SyncBN'),
+        act_cfg=dict(type='SiLU'),
+        init_cfg=dict(
+            type='Pretrained',
+            prefix='backbone.',
+            checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
+            'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth'
+        )),
+    head=dict(
+        type='RTMCCHead',
+        in_channels=768,
+        out_channels=20,
+        input_size=codec['input_size'],
+        in_featuremap_size=(8, 8),
+        simcc_split_ratio=codec['simcc_split_ratio'],
+        final_layer_kernel_size=7,
+        gau_cfg=dict(
+            hidden_dims=256,
+            s=128,
+            expansion_factor=2,
+            dropout_rate=0.,
+            drop_path=0.,
+            act_fn='SiLU',
+            use_rel_bias=False,
+            pos_enc=False),
+        loss=dict(
+            type='KLDiscretLoss',
+            use_target_weight=True,
+            beta=10.,
+            label_softmax=True),
+        decoder=codec),
+    test_cfg=dict(flip_test=True))
+
+# base dataset settings
+dataset_type = 'AnimalPoseDataset'
+data_mode = 'topdown'
+data_root = 'data/animalpose/'
+
+backend_args = dict(backend='local')
+
+# pipelines
+train_pipeline = [
+    dict(type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='RandomFlip', direction='horizontal'),
+    dict(type='RandomHalfBody'),
+    dict(type='RandomBBoxTransform'),
+    dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(type='GenerateTarget', encoder=codec),
+    dict(type='PackPoseInputs')
+]
+val_pipeline = [
+    dict(type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(type='PackPoseInputs')
+]
+
+# data loaders
+train_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=num_workers,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/animalpose_train.json',
+        data_prefix=dict(img=''),
+        pipeline=train_pipeline,
+    ))
+val_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=num_workers,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/animalpose_val.json',
+        data_prefix=dict(img=''),
+        test_mode=True,
+        pipeline=val_pipeline,
+    ))
+test_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=num_workers,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/animalpose_test.json',
+        data_prefix=dict(img=''),
+        test_mode=True,
+        pipeline=val_pipeline,
+    ))
+
+# evaluators
+val_evaluator = dict(
+    type='CocoMetric', ann_file=data_root + 'annotations/animalpose_val.json')
+test_evaluator = dict(
+    type='CocoMetric', ann_file=data_root + 'annotations/animalpose_test.json')
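`auto_scale_lr` records the reference batch size for MMEngine's linear LR scaling rule, which takes effect only when auto-scale-LR is switched on at launch. A back-of-the-envelope sketch of the rule as commonly stated (scaled LR = base LR × effective batch / reference batch); `num_gpus` is a hypothetical launch size:

```python
# Sketch of MMEngine's linear LR scaling rule using the values from the
# config above; nothing here is applied automatically unless the trainer's
# auto-scale-LR option is enabled.
base_lr = 5e-4
base_batch_size = 64   # auto_scale_lr['base_batch_size'] in the config
batch_per_gpu = 64     # train_dataloader batch_size
num_gpus = 8           # hypothetical, as suggested by the 8xb64 file name

effective_batch = num_gpus * batch_per_gpu               # 512
scaled_lr = base_lr * effective_batch / base_batch_size  # 4e-3
print(scaled_lr)
```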
diff --git a/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-s_8xb64-210e_animalpose-256x256.py b/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-s_8xb64-210e_animalpose-256x256.py
new file mode 100644
index 0000000000..88d8850a18
--- /dev/null
+++ b/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-s_8xb64-210e_animalpose-256x256.py
@@ -0,0 +1,172 @@
+_base_ = ['../../../_base_/default_runtime.py']
+
+# runtime
+max_epochs = 210
+base_lr = 5e-4
+num_workers = 8
+
+train_cfg = dict(max_epochs=max_epochs, val_interval=10)
+randomness = dict(seed=23)
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(
+    type='Adam',
+    lr=base_lr,
+))
+
+# learning rate
+batch_size = 64
+param_scheduler = [
+    dict(
+        type='LinearLR', begin=0, end=500, start_factor=0.001,
+        by_epoch=False),  # warm-up
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=max_epochs,
+        milestones=[170, 200],
+        gamma=0.1,
+        by_epoch=True)
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=batch_size)
+
+# hooks
+default_hooks = dict(
+    checkpoint=dict(save_best='animalpose/AP', rule='greater'))
+
+# codec settings
+codec = dict(
+    type='SimCCLabel',
+    input_size=(256, 256),
+    sigma=(5.66, 5.66),
+    simcc_split_ratio=2.0,
+    normalize=False,
+    use_dark=False)
+
+# model settings
+model = dict(
+    type='TopdownPoseEstimator',
+    data_preprocessor=dict(
+        type='PoseDataPreprocessor',
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        bgr_to_rgb=True),
+    backbone=dict(
+        _scope_='mmdet',
+        type='CSPNeXt',
+        arch='P5',
+        expand_ratio=0.5,
+        deepen_factor=0.33,
+        widen_factor=0.5,
+        out_indices=(4, ),
+        channel_attention=True,
+        norm_cfg=dict(type='SyncBN'),
+        act_cfg=dict(type='SiLU'),
+        init_cfg=dict(
+            type='Pretrained',
+            prefix='backbone.',
+            checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
+            'rtmpose/cspnext-s_udp-aic-coco_210e-256x192-92f5a029_20230130.pth'
+        )),
+    head=dict(
+        type='RTMCCHead',
+        in_channels=512,
+        out_channels=20,
+        input_size=codec['input_size'],
+        in_featuremap_size=(8, 8),
+        simcc_split_ratio=codec['simcc_split_ratio'],
+        final_layer_kernel_size=7,
+        gau_cfg=dict(
+            hidden_dims=256,
+            s=128,
+            expansion_factor=2,
+            dropout_rate=0.,
+            drop_path=0.,
+            act_fn='SiLU',
+            use_rel_bias=False,
+            pos_enc=False),
+        loss=dict(
+            type='KLDiscretLoss',
+            use_target_weight=True,
+            beta=10.,
+            label_softmax=True),
+        decoder=codec),
+    test_cfg=dict(flip_test=True))
+
+# base dataset settings
+dataset_type = 'AnimalPoseDataset'
+data_mode = 'topdown'
+data_root = 'data/animalpose/'
+
+backend_args = dict(backend='local')
+
+# pipelines
+train_pipeline = [
+    dict(type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='RandomFlip', direction='horizontal'),
+    dict(type='RandomHalfBody'),
+    dict(type='RandomBBoxTransform'),
+    dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(type='GenerateTarget', encoder=codec),
+    dict(type='PackPoseInputs')
+]
+val_pipeline = [
+    dict(type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(type='PackPoseInputs')
+]
+
+# data loaders
+train_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=num_workers,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/animalpose_train.json',
+        data_prefix=dict(img=''),
+        pipeline=train_pipeline,
+    ))
+val_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=num_workers,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/animalpose_val.json',
+        data_prefix=dict(img=''),
+        test_mode=True,
+        pipeline=val_pipeline,
+    ))
+test_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=num_workers,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/animalpose_test.json',
+        data_prefix=dict(img=''),
+        test_mode=True,
+        pipeline=val_pipeline,
+    ))
+
+# evaluators
+val_evaluator = dict(
+    type='CocoMetric', ann_file=data_root + 'annotations/animalpose_val.json')
+test_evaluator = dict(
+    type='CocoMetric', ann_file=data_root + 'annotations/animalpose_test.json')
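The `param_scheduler` pair shared by these configs composes a 500-iteration linear warm-up with step decay at epochs 170 and 200. A plain-Python approximation of the resulting LR curve, useful for eyeballing values (`iters_per_epoch` is hypothetical and dataset-dependent):

```python
# Sketch of the LR schedule defined by param_scheduler: LinearLR warm-up
# from base_lr * 0.001 over 500 iterations, then MultiStepLR decay
# (gamma=0.1) at epochs 170 and 200. Approximates MMEngine's behaviour.
base_lr = 5e-4

def lr_at(epoch, it, iters_per_epoch=100):
    global_iter = epoch * iters_per_epoch + it
    warmup = min(1.0, 0.001 + (1 - 0.001) * global_iter / 500)
    decay = 0.1 ** sum(epoch >= m for m in (170, 200))
    return base_lr * warmup * decay

print(lr_at(0, 0))    # ~5e-7, first warm-up iteration
print(lr_at(100, 0))  # 5e-4, plateau after warm-up
print(lr_at(180, 0))  # 5e-5, after the first milestone
print(lr_at(205, 0))  # 5e-6, after the second milestone
```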
diff --git a/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-t_8xb64-210e_animalpose-256x256.py b/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-t_8xb64-210e_animalpose-256x256.py
new file mode 100644
index 0000000000..36b7326511
--- /dev/null
+++ b/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-t_8xb64-210e_animalpose-256x256.py
@@ -0,0 +1,172 @@
+_base_ = ['../../../_base_/default_runtime.py']
+
+# runtime
+max_epochs = 210
+base_lr = 5e-4
+num_workers = 8
+
+train_cfg = dict(max_epochs=max_epochs, val_interval=10)
+randomness = dict(seed=23)
+
+# optimizer
+optim_wrapper = dict(optimizer=dict(
+    type='Adam',
+    lr=base_lr,
+))
+
+# learning rate
+batch_size = 64
+param_scheduler = [
+    dict(
+        type='LinearLR', begin=0, end=500, start_factor=0.001,
+        by_epoch=False),  # warm-up
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=max_epochs,
+        milestones=[170, 200],
+        gamma=0.1,
+        by_epoch=True)
+]
+
+# automatically scaling LR based on the actual training batch size
+auto_scale_lr = dict(base_batch_size=batch_size)
+
+# hooks
+default_hooks = dict(
+    checkpoint=dict(save_best='animalpose/AP', rule='greater'))
+
+# codec settings
+codec = dict(
+    type='SimCCLabel',
+    input_size=(256, 256),
+    sigma=(5.66, 5.66),
+    simcc_split_ratio=2.0,
+    normalize=False,
+    use_dark=False)
+
+# model settings
+model = dict(
+    type='TopdownPoseEstimator',
+    data_preprocessor=dict(
+        type='PoseDataPreprocessor',
+        mean=[123.675, 116.28, 103.53],
+        std=[58.395, 57.12, 57.375],
+        bgr_to_rgb=True),
+    backbone=dict(
+        _scope_='mmdet',
+        type='CSPNeXt',
+        arch='P5',
+        expand_ratio=0.5,
+        deepen_factor=0.167,
+        widen_factor=0.375,
+        out_indices=(4, ),
+        channel_attention=True,
+        norm_cfg=dict(type='SyncBN'),
+        act_cfg=dict(type='SiLU'),
+        init_cfg=dict(
+            type='Pretrained',
+            prefix='backbone.',
+            checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
+            'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth'
+        )),
+    head=dict(
+        type='RTMCCHead',
+        in_channels=384,
+        out_channels=20,
+        input_size=codec['input_size'],
+        in_featuremap_size=(8, 8),
+        simcc_split_ratio=codec['simcc_split_ratio'],
+        final_layer_kernel_size=7,
+        gau_cfg=dict(
+            hidden_dims=256,
+            s=128,
+            expansion_factor=2,
+            dropout_rate=0.,
+            drop_path=0.,
+            act_fn='SiLU',
+            use_rel_bias=False,
+            pos_enc=False),
+        loss=dict(
+            type='KLDiscretLoss',
+            use_target_weight=True,
+            beta=10.,
+            label_softmax=True),
+        decoder=codec),
+    test_cfg=dict(flip_test=True))
+
+# base dataset settings
+dataset_type = 'AnimalPoseDataset'
+data_mode = 'topdown'
+data_root = 'data/animalpose/'
+
+backend_args = dict(backend='local')
+
+# pipelines
+train_pipeline = [
+    dict(type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='RandomFlip', direction='horizontal'),
+    dict(type='RandomHalfBody'),
+    dict(type='RandomBBoxTransform'),
+    dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(type='GenerateTarget', encoder=codec),
+    dict(type='PackPoseInputs')
+]
+val_pipeline = [
+    dict(type='LoadImage'),
+    dict(type='GetBBoxCenterScale'),
+    dict(type='TopdownAffine', input_size=codec['input_size']),
+    dict(type='PackPoseInputs')
+]
+
+# data loaders
+train_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=num_workers,
+    persistent_workers=True,
+    sampler=dict(type='DefaultSampler', shuffle=True),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/animalpose_train.json',
+        data_prefix=dict(img=''),
+        pipeline=train_pipeline,
+    ))
+val_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=num_workers,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/animalpose_val.json',
+        data_prefix=dict(img=''),
+        test_mode=True,
+        pipeline=val_pipeline,
+    ))
+test_dataloader = dict(
+    batch_size=batch_size,
+    num_workers=num_workers,
+    persistent_workers=True,
+    drop_last=False,
+    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
+    dataset=dict(
+        type=dataset_type,
+        data_root=data_root,
+        data_mode=data_mode,
+        ann_file='annotations/animalpose_test.json',
+        data_prefix=dict(img=''),
+        test_mode=True,
+        pipeline=val_pipeline,
+    ))
+
+# evaluators
+val_evaluator = dict(
+    type='CocoMetric', ann_file=data_root + 'annotations/animalpose_val.json')
+test_evaluator = dict(
+    type='CocoMetric', ann_file=data_root + 'annotations/animalpose_test.json')
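The `SimCCLabel` codec shared by all four configs turns each keypoint coordinate into a 1-D classification target over `input_size * simcc_split_ratio` bins (512 per axis here), smoothed by a Gaussian of width `sigma`. A rough NumPy sketch of the idea for the x axis; this mirrors the SimCC formulation, not MMPose's exact implementation (normalization and target weighting are omitted):

```python
# Rough sketch of a SimCC-style 1-D target for a single x coordinate,
# using the codec values from the configs above.
import numpy as np

input_w, split_ratio, sigma = 256, 2.0, 5.66
bins = int(input_w * split_ratio)  # 512

def encode_x(x_px):
    mu = x_px * split_ratio
    grid = np.arange(bins, dtype=np.float32)
    return np.exp(-((grid - mu) ** 2) / (2 * sigma ** 2))

target = encode_x(128.0)              # keypoint at the horizontal centre
print(target.argmax() / split_ratio)  # ~128.0 when decoded back
```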
diff --git a/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose_animalpose.md b/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose_animalpose.md
new file mode 100644
index 0000000000..b8a0f85b6d
--- /dev/null
+++ b/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose_animalpose.md
@@ -0,0 +1,46 @@
+
+RTMPose (arXiv'2023)
+
+```bibtex
+@misc{https://doi.org/10.48550/arxiv.2303.07399,
+  doi = {10.48550/ARXIV.2303.07399},
+  url = {https://arxiv.org/abs/2303.07399},
+  author = {Jiang, Tao and Lu, Peng and Zhang, Li and Ma, Ningsheng and Han, Rui and Lyu, Chengqi and Li, Yining and Chen, Kai},
+  keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
+  title = {RTMPose: Real-Time Multi-Person Pose Estimation based on MMPose},
+  publisher = {arXiv},
+  year = {2023},
+  copyright = {Creative Commons Attribution 4.0 International}
+}
+```
+
+ + + +
+Animal-Pose (ICCV'2019) + +```bibtex +@InProceedings{Cao_2019_ICCV, + author = {Cao, Jinkun and Tang, Hongyang and Fang, Hao-Shu and Shen, Xiaoyong and Lu, Cewu and Tai, Yu-Wing}, + title = {Cross-Domain Adaptation for Animal Pose Estimation}, + booktitle = {The IEEE International Conference on Computer Vision (ICCV)}, + month = {October}, + year = {2019} +} +``` + +
+ +Results on AnimalPose validation set + +| Arch | Input Size | AP | AP50 | AP75 | AR | AR50 | ckpt | log | +| :----------------------------------------------------------------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :--------: | :-------: | +| [rtmpose-t](/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-t_8xb64-210e_animalpose-256x256.py) | 256x256 | 0.680 | 0.927 | 0.770 | 0.718 | 0.934 | [ckpt](<>) | [log](<>) | +| [rtmpose-s](/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-s_8xb64-210e_animalpose-256x256.py) | 256x256 | 0.709 | 0.938 | 0.799 | 0.748 | 0.946 | [ckpt](<>) | [log](<>) | +| [rtmpose-m](/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-m_8xb64-210e_animalpose-256x256.py) | 256x256 | 0.598 | 0.896 | 0.653 | 0.642 | 0.900 | [ckpt](<>) | [log](<>) | +| [rtmpose-l](/configs/animal_2d_keypoint/rtmpose/animalpose/rtmpose-l_8xb64-210e_animalpose-256x256.py) | 256x256 | 0.766 | 0.959 | 0.855 | 0.800 | 0.968 | [ckpt](<>) | [log](<>) |
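The ckpt/log cells above are `<>` placeholders for checkpoints yet to be released. To reproduce a row, a config can be run from Python as well as from the CLI; a minimal sketch, assuming an mmpose 1.x install and the Animal-Pose data laid out under `data/animalpose/` as configured (the `work_dir` value is hypothetical):

```python
# Minimal sketch of launching training or evaluation for one of the
# configs in this PR via MMEngine's Runner.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile(
    'configs/animal_2d_keypoint/rtmpose/animalpose/'
    'rtmpose-m_8xb64-210e_animalpose-256x256.py')
cfg.work_dir = 'work_dirs/rtmpose-m_animalpose'  # hypothetical output dir

runner = Runner.from_cfg(cfg)
runner.train()  # or runner.test() with cfg.load_from set to a checkpoint
```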