[add] ConcatDataset.evaluate feature #1932
base: master
@@ -1,7 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
from mmcv.utils import print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .builder import DATASETS


@DATASETS.register_module()
class ConcatDataset(_ConcatDataset):
    """A wrapper of concatenated dataset.

    Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
    concat the group flag for image aspect ratio.

    Args:
        datasets (list[:obj:`Dataset`]): A list of datasets.
        separate_eval (bool): Whether to evaluate the results
            separately if it is used as validation dataset.
            Defaults to True.
    """

    def __init__(self, datasets, separate_eval=True):
        super(ConcatDataset, self).__init__(datasets)
        self.separate_eval = separate_eval

    def clip_by_index(self, results, start_idx, end_idx):
        clip_res = dict()
        clip_res['preds'] = results[0]['preds'][start_idx:end_idx, :, :]
        clip_res['boxes'] = results[0]['boxes'][start_idx:end_idx, :]
        clip_res['image_paths'] = results[0]['image_paths'][start_idx:end_idx]
        clip_res['bbox_ids'] = results[0]['bbox_ids'][start_idx:end_idx]
        return [clip_res]

    def rebuild_results(self, results):
        new_results = list()
        new_res = results[0]
        new_results.append(new_res)
        for i in range(1, len(results)):
            res = results[i]
            new_res['preds'] = np.append(new_res['preds'], res['preds'], axis=0)
            new_res['boxes'] = np.append(new_res['boxes'], res['boxes'], axis=0)
            new_res['image_paths'] += res['image_paths']
            new_res['bbox_ids'] += res['bbox_ids']
        return new_results, len(new_res['image_paths'])

    def update_total_eval_results_ap(self, eval_ap_list, total_eval_results):
        sum_num = 0
        sum_precision = 0
        for eval_ap in eval_ap_list:
            numbers = eval_ap['numbers']
            precisions = numbers * eval_ap['AP']
            sum_num += numbers
            sum_precision += precisions
        total_eval_results.update({'AP': sum_precision / sum_num})

    def evaluate(self, results, logger=None, **kwargs):
        """Evaluate the results.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            logger (logging.Logger | str | None): Logger used for printing
                related information during evaluation. Default: None.

        Returns:
            dict[str: float]: AP results of the total dataset or each
                separate dataset if `self.separate_eval=True`.
        """

        new_results, res_len = self.rebuild_results(results)
Review comment: Could you please explain the reason for concatenating all elements into …
Author reply: `results` mixes the outputs of several val datasets together, and every `batch_size` outputs form one dict in `results`. I need to split out the results belonging to each val dataset using the indices recorded in `cumulative_sizes`, then run each dataset's `evaluate` separately. Merging everything into a single results dict first and then splitting seemed more convenient to me; I am not sure whether my understanding is correct.
Review comment: Thanks. What I was unsure about is whether it might be possible to do this directly in …
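A minimal sketch of the merge-then-split flow discussed above (not part of the PR; the key names and array shapes are assumed from clip_by_index, and cumulative_sizes stands in for torch's ConcatDataset attribute):

# Sketch: merge per-batch result dicts, then re-split them per dataset.
import numpy as np

batch_results = [
    dict(preds=np.zeros((2, 17, 3)), boxes=np.zeros((2, 5)),
         image_paths=['a.jpg', 'b.jpg'], bbox_ids=[0, 1]),
    dict(preds=np.zeros((2, 17, 3)), boxes=np.zeros((2, 5)),
         image_paths=['c.jpg', 'd.jpg'], bbox_ids=[2, 3]),
]
cumulative_sizes = [3, 4]  # dataset 0 owns samples 0..2, dataset 1 owns 3

# Merge all batches into one dict (what rebuild_results does) ...
merged = dict(
    preds=np.concatenate([r['preds'] for r in batch_results], axis=0),
    boxes=np.concatenate([r['boxes'] for r in batch_results], axis=0),
    image_paths=sum((r['image_paths'] for r in batch_results), []),
    bbox_ids=sum((r['bbox_ids'] for r in batch_results), []))

# ... then slice out each dataset's share (what clip_by_index does).
start = 0
for end in cumulative_sizes:
    per_dataset = {k: v[start:end] for k, v in merged.items()}
    start = end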
        assert res_len == self.cumulative_sizes[-1], \
            ('Dataset and results have different sizes: '
             f'{self.cumulative_sizes[-1]} v.s. {res_len}')

        # Check whether all the datasets support evaluation
        for dataset in self.datasets:
            assert hasattr(dataset, 'evaluate'), \
                f'{type(dataset)} does not implement evaluate function'

        if self.separate_eval:
            total_eval_results = dict()
            eval_ap_list = list()
            for dataset_idx, dataset in enumerate(self.datasets):
                start_idx = 0 if dataset_idx == 0 else \
                    self.cumulative_sizes[dataset_idx - 1]
                end_idx = self.cumulative_sizes[dataset_idx]

                results_per_dataset = self.clip_by_index(
                    new_results, start_idx, end_idx)
                print_log(
                    f'Evaluating dataset-{dataset_idx} with '
                    f'{end_idx - start_idx} images now',
                    logger=logger)

                eval_results_per_dataset = dataset.evaluate(
                    results_per_dataset, logger=logger, **kwargs)
                for k, v in eval_results_per_dataset.items():
                    total_eval_results.update({f'{dataset_idx}_{k}': v})

                if 'AP' in eval_results_per_dataset.keys():
                    eval_ap_list.append(
                        dict(
                            numbers=end_idx - start_idx,
                            AP=eval_results_per_dataset['AP']))

            if len(eval_ap_list) > 0:
                self.update_total_eval_results_ap(eval_ap_list,
                                                  total_eval_results)
            return total_eval_results
        elif len(set([type(ds) for ds in self.datasets])) != 1:
            raise NotImplementedError(
                'All the datasets should have same types')
Comment on lines +106 to +108
Review comment: Do we need to perform this check if …
Author reply: Indeed, it is not needed; I copied this code from mmdet and did not read it carefully 😄
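As a side note on the overall 'AP' entry: update_total_eval_results_ap above computes a dataset-size-weighted mean of the per-dataset APs. A small worked example with made-up numbers:

# Two datasets with 100 and 300 samples and APs of 0.70 and 0.80.
eval_ap_list = [dict(numbers=100, AP=0.70), dict(numbers=300, AP=0.80)]
sum_num = sum(e['numbers'] for e in eval_ap_list)                   # 400
sum_precision = sum(e['numbers'] * e['AP'] for e in eval_ap_list)   # 310.0
total_ap = sum_precision / sum_num                                  # 0.775

Weighting by sample count lets larger datasets contribute proportionally more, rather than taking a plain average of the two APs (which would give 0.75).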
        else:
            raise NotImplementedError(
                'separate_eval=False not implemented')


@DATASETS.register_module()
class RepeatDataset:
    """A wrapper of repeated dataset.
Review comment: Bottom-up models may not generate predictions with key `bbox_ids`, so a key check may be needed here.
Reply: `get`
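A possible guard along the lines suggested (hypothetical; the PR itself does not include it), using dict.get so results that lack `bbox_ids` still slice cleanly:

# Hypothetical variant of clip_by_index for results without 'bbox_ids'
# (e.g. from bottom-up models); mirrors the PR's slicing otherwise.
def clip_by_index_safe(results, start_idx, end_idx):
    clip_res = dict()
    clip_res['preds'] = results[0]['preds'][start_idx:end_idx, :, :]
    clip_res['boxes'] = results[0]['boxes'][start_idx:end_idx, :]
    clip_res['image_paths'] = results[0]['image_paths'][start_idx:end_idx]
    bbox_ids = results[0].get('bbox_ids')  # may be absent for bottom-up
    if bbox_ids is not None:
        clip_res['bbox_ids'] = bbox_ids[start_idx:end_idx]
    return [clip_res]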