diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py
index 0411f1ee..b6e7d136 100644
--- a/mmeval/metrics/__init__.py
+++ b/mmeval/metrics/__init__.py
@@ -17,6 +17,7 @@
 from .keypoint_auc import KeypointAUC
 from .keypoint_epe import KeypointEndPointError
 from .keypoint_nme import KeypointNME
+from .lvis_detection import LVISDetection
 from .mae import MeanAbsoluteError
 from .matting_mse import MattingMeanSquaredError
 from .mean_iou import MeanIoU
@@ -48,7 +49,7 @@
     'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator',
     'WordAccuracy', 'PrecisionRecallF1score',
     'SingleLabelPrecisionRecallF1score', 'MultiLabelPrecisionRecallF1score',
-    'CharRecallPrecision'
+    'LVISDetection', 'CharRecallPrecision'
 ]
 
 _deprecated_msg = (
diff --git a/mmeval/metrics/lvis_detection.py b/mmeval/metrics/lvis_detection.py
new file mode 100644
index 00000000..8e216b29
--- /dev/null
+++ b/mmeval/metrics/lvis_detection.py
@@ -0,0 +1,375 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import contextlib
+import io
+import itertools
+import numpy as np
+import os.path as osp
+import tempfile
+from collections import OrderedDict
+from rich.console import Console
+from rich.table import Table
+from typing import Dict, List, Optional, Sequence, Union
+
+from mmeval.fileio import get_local_path
+from .coco_detection import COCODetection
+
+try:
+    from lvis import LVIS, LVISEval, LVISResults
+    HAS_LVISAPI = True
+except ImportError:
+    HAS_LVISAPI = False
+
+
+class LVISDetection(COCODetection):
+    """LVIS evaluation metric.
+
+    Evaluate AR and AP for detection tasks on the LVIS dataset, including
+    proposal/box detection and instance segmentation.
+
+    Args:
+        ann_file (str): Path to the LVIS dataset annotation file.
+        metric (str | List[str]): Metrics to be evaluated. Valid metrics
+            include 'bbox', 'segm', and 'proposal'. Defaults to 'bbox'.
+        iou_thrs (float | List[float], optional): IoU threshold to compute AP
+            and AR. If not specified, IoUs from 0.5 to 0.95 will be used.
+            Defaults to None.
+        classwise (bool): Whether to return the computed results of each
+            class. Defaults to False.
+        proposal_nums (int): Number of proposals to be evaluated.
+            Defaults to 300.
+        metric_items (List[str], optional): Metric result names to be
+            recorded in the evaluation result. If None, default configurations
+            in LVIS will be used. Defaults to None.
+        format_only (bool): Format the output results without performing
+            evaluation. It is useful when you want to format the result
+            to a specific format and submit it to the test server.
+            Defaults to False.
+        outfile_prefix (str, optional): The prefix of json files. It includes
+            the file path and the prefix of filename, e.g., "a/b/prefix".
+            If not specified, a temp file will be created. Defaults to None.
+        backend_args (dict, optional): Arguments to instantiate the
+            corresponding backend of the URI prefix. Defaults to None.
+        print_results (bool): Whether to print the results. Defaults to True.
+        logger (Logger, optional): logger used to record messages. When set to
+            ``None``, the default logger will be used.
+            Defaults to None.
+        **kwargs: Keyword parameters passed to :class:`BaseMetric`.
+
+    Examples:
+        >>> import numpy as np
+        >>> from mmeval import LVISDetection
+        >>> try:
+        >>>     from mmeval.metrics.utils.coco_wrapper import mask_util
+        >>> except ImportError as e:
+        >>>     mask_util = None
+        >>>
+        >>> num_classes = 4
+        >>> fake_dataset_metas = {
+        ...     'classes': tuple([str(i) for i in range(num_classes)])
+        ... }
+        >>>
+        >>> lvis_det_metric = LVISDetection(
+        ...     ann_file='data/lvis_v1/annotations/lvis_v1_train.json',
+        ...     dataset_meta=fake_dataset_metas,
+        ...     metric=['bbox', 'segm']
+        ... )
+        >>> lvis_det_metric(predictions=predictions)  # doctest: +ELLIPSIS  # noqa: E501
+        {'bbox_AP': ..., 'bbox_AP50': ..., ...,
+         'segm_AP': ..., 'segm_AP50': ..., ...,}
+    """
+
+    def __init__(self,
+                 ann_file: str,
+                 metric: Union[str, List[str]] = 'bbox',
+                 classwise: bool = False,
+                 proposal_nums: int = 300,
+                 iou_thrs: Optional[Union[float, Sequence[float]]] = None,
+                 metric_items: Optional[Sequence[str]] = None,
+                 format_only: bool = False,
+                 outfile_prefix: Optional[str] = None,
+                 backend_args: Optional[dict] = None,
+                 **kwargs) -> None:
+        if not HAS_LVISAPI:
+            raise RuntimeError(
+                'Package lvis is not installed. Please run "pip install '
+                'git+https://github.com/lvis-dataset/lvis-api.git".')
+        super().__init__(
+            metric=metric,
+            classwise=classwise,
+            iou_thrs=iou_thrs,
+            metric_items=metric_items,
+            format_only=format_only,
+            outfile_prefix=outfile_prefix,
+            backend_args=backend_args,
+            **kwargs)
+        self.proposal_nums = proposal_nums  # type: ignore
+
+        with get_local_path(
+                filepath=ann_file, backend_args=backend_args) as local_path:
+            self._lvis_api = LVIS(local_path)
+
+    def add_predictions(self, predictions: Sequence[Dict]) -> None:
+        """Add predictions to `self._results`.
+
+        Args:
+            predictions (Sequence[dict]): A sequence of dict. Each dict
+                represents a detection result for an image, with the
+                following keys:
+
+                - img_id (int): Image id.
+                - bboxes (numpy.ndarray): Shape (N, 4), the predicted
+                  bounding boxes of this image, in 'xyxy' format.
+                - scores (numpy.ndarray): Shape (N, ), the predicted scores
+                  of bounding boxes.
+                - labels (numpy.ndarray): Shape (N, ), the predicted labels
+                  of bounding boxes.
+                - masks (list[RLE], optional): The predicted masks.
+                - mask_scores (np.array, optional): Shape (N, ), the predicted
+                  scores of masks.
+        """
+        self.add(predictions)
+
+    def add(self, predictions: Sequence[Dict]) -> None:  # type: ignore # yapf: disable # noqa: E501
+        """Add the intermediate results to `self._results`.
+
+        Args:
+            predictions (Sequence[dict]): A sequence of dict. Each dict
+                represents a detection result for an image, with the
+                following keys:
+
+                - img_id (int): Image id.
+                - bboxes (numpy.ndarray): Shape (N, 4), the predicted
+                  bounding boxes of this image, in 'xyxy' format.
+                - scores (numpy.ndarray): Shape (N, ), the predicted scores
+                  of bounding boxes.
+                - labels (numpy.ndarray): Shape (N, ), the predicted labels
+                  of bounding boxes.
+                - masks (list[RLE], optional): The predicted masks.
+                - mask_scores (np.array, optional): Shape (N, ), the predicted
+                  scores of masks.
+        """
+        for prediction in predictions:
+            assert isinstance(prediction, dict), 'The prediction should be ' \
+                f'a sequence of dict, but got a sequence of {type(prediction)}.'  # noqa: E501
+            self._results.append(prediction)
+
+    def __call__(self, *args, **kwargs) -> Dict:
+        """Stateless call for a metric compute."""
+
+        # cache states
+        cache_results = self._results
+        cache_lvis_api = self._lvis_api
+        cache_cat_ids = self.cat_ids
+        cache_img_ids = self.img_ids
+
+        self._results = []
+        self.add(*args, **kwargs)
+        metric_result = self.compute_metric(self._results)
+
+        # recover states from cache
+        self._results = cache_results
+        self._lvis_api = cache_lvis_api
+        self.cat_ids = cache_cat_ids
+        self.img_ids = cache_img_ids
+
+        return metric_result
+
+    def compute_metric(  # type: ignore
+            self, results: list) -> Dict[str, Union[float, list]]:
+        """Compute the LVIS metrics.
+
+        Args:
+            results (List[dict]): A list of dict. Each dict is the detection
+                result of an image. This list has already been synced across
+                all ranks.
+
+        Returns:
+            dict: The computed metric. The keys are the names of
+            the metrics, and the values are corresponding results.
+        """
+        tmp_dir = None
+        if self.outfile_prefix is None:
+            tmp_dir = tempfile.TemporaryDirectory()
+            outfile_prefix = osp.join(tmp_dir.name, 'results')
+        else:
+            outfile_prefix = self.outfile_prefix
+
+        # handle lazy init
+        if len(self.cat_ids) == 0:
+            self.cat_ids = self._lvis_api.get_cat_ids()
+        if len(self.img_ids) == 0:
+            self.img_ids = self._lvis_api.get_img_ids()
+
+        # convert predictions to coco format and dump to json file
+        result_files = self.results2json(results, outfile_prefix)
+
+        eval_results: OrderedDict = OrderedDict()
+        table_results: OrderedDict = OrderedDict()
+        if self.format_only:
+            self.logger.info('results are saved in '
+                             f'{osp.dirname(outfile_prefix)}')
+            return eval_results
+
+        lvis_gt = self._lvis_api
+
+        for metric in self.metrics:
+            self.logger.info(f'Evaluating {metric}...')
+
+            try:
+                lvis_dt = LVISResults(lvis_gt, result_files[metric])
+            except IndexError:
+                self.logger.warning(
+                    'The testing results of the whole dataset are empty.')
+                break
+
+            iou_type = 'bbox' if metric == 'proposal' else metric
+            lvis_eval = LVISEval(lvis_gt, lvis_dt, iou_type)
+            lvis_eval.params.imgIds = self.img_ids
+            metric_items = self.metric_items
+            if metric == 'proposal':
+                lvis_eval.params.max_dets = self.proposal_nums
+                lvis_eval.evaluate()
+                lvis_eval.accumulate()
+                lvis_eval.summarize()
+                if metric_items is None:
+                    metric_items = [
+                        f'AR@{self.proposal_nums}',
+                        f'ARs@{self.proposal_nums}',
+                        f'ARm@{self.proposal_nums}',
+                        f'ARl@{self.proposal_nums}'
+                    ]
+                results_list = []
+                for k, v in lvis_eval.get_results().items():
+                    if k in metric_items:
+                        val = float(v)
+                        results_list.append(f'{round(val * 100, 2):0.2f}')
+                        eval_results[k] = val
+                table_results[f'{metric}_result'] = results_list
+
+            else:
+                lvis_eval.evaluate()
+                lvis_eval.accumulate()
+                lvis_eval.summarize()
+                lvis_results = lvis_eval.get_results()
+
+                if metric_items is None:
+                    metric_items = [
+                        'AP', 'AP50', 'AP75', 'APs', 'APm', 'APl', 'APr',
+                        'APc', 'APf'
+                    ]
+
+                results_list = []
+                for metric_item, v in lvis_results.items():
+                    if metric_item in metric_items:
+                        key = f'{metric}_{metric_item}'
+                        val = float(v)
+                        results_list.append(f'{round(val * 100, 2)}')
+                        eval_results[key] = val
+                table_results[f'{metric}_result'] = results_list
+
+                if self.classwise:  # Compute per-category AP
+                    # Compute per-category AP
+                    # from https://github.com/facebookresearch/detectron2/
+                    precisions = lvis_eval.eval['precision']
+                    # precision: (iou, recall, cls, area range, max dets)
+                    assert len(self.cat_ids) == precisions.shape[2]
+
+                    results_per_category = []
+                    for idx, catId in enumerate(self.cat_ids):
+                        # area range index 0: all area ranges
+                        # max dets index -1: typically 100 per image
+                        # the dimensions of precisions are
+                        # [num_thrs, num_recalls, num_cats, num_area_rngs]
+                        nm = self._lvis_api.load_cats([catId])[0]
+                        precision = precisions[:, :, idx, 0]
+                        precision = precision[precision > -1]
+                        if precision.size:
+                            ap = np.mean(precision)
+                        else:
+                            ap = float('nan')
+                        results_per_category.append(
+                            (f'{nm["name"]}', f'{round(ap * 100, 2)}'))
+                        eval_results[f'{metric}_{nm["name"]}_precision'] = ap
+
+                    table_results[f'{metric}_classwise_result'] = \
+                        results_per_category
+            # Save lvis summarize print information to logger
+            redirect_string = io.StringIO()
+            with contextlib.redirect_stdout(redirect_string):
+                lvis_eval.print_results()
+            self.logger.info('\n' + redirect_string.getvalue())
+        if tmp_dir is not None:
+            tmp_dir.cleanup()
+        # if the testing results of the whole dataset are empty,
+        # do not print tables.
+        if self.print_results and len(table_results) > 0:
+            self._print_results(table_results)
+        return eval_results
+
+    def _print_results(self, table_results: dict) -> None:
+        """Print the evaluation results table.
+
+        Args:
+            table_results (dict): The computed metric results.
+        """
+        for metric in self.metrics:
+            result = table_results[f'{metric}_result']
+
+            if metric == 'proposal':
+                table_title = ' Recall Results (%)'
+                if self.metric_items is None:
+                    assert len(result) == 4
+                    headers = [
+                        f'AR@{self.proposal_nums}',
+                        f'ARs@{self.proposal_nums}',
+                        f'ARm@{self.proposal_nums}',
+                        f'ARl@{self.proposal_nums}'
+                    ]
+                else:
+                    assert len(result) == len(self.metric_items)  # type: ignore # yapf: disable # noqa: E501
+                    headers = self.metric_items  # type: ignore
+            else:
+                table_title = f' {metric} Results (%)'
+                if self.metric_items is None:
+                    assert len(result) == 9
+                    headers = [
+                        f'{metric}_AP', f'{metric}_AP50', f'{metric}_AP75',
+                        f'{metric}_APs', f'{metric}_APm', f'{metric}_APl',
+                        f'{metric}_APr', f'{metric}_APc', f'{metric}_APf'
+                    ]
+                else:
+                    assert len(result) == len(self.metric_items)
+                    headers = [
+                        f'{metric}_{item}' for item in self.metric_items
+                    ]
+            table = Table(title=table_title)
+            console = Console()
+            for name in headers:
+                table.add_column(name, justify='left')
+            table.add_row(*result)
+            with console.capture() as capture:
+                console.print(table, end='')
+            self.logger.info('\n' + capture.get())
+
+            if self.classwise and metric != 'proposal':
+                self.logger.info(
+                    f'Evaluating {metric} metric of each category...')
+                classwise_table_title = f' {metric} Classwise Results (%)'
+                classwise_result = table_results[f'{metric}_classwise_result']
+
+                num_columns = min(6, len(classwise_result) * 2)
+                results_flatten = list(itertools.chain(*classwise_result))
+                headers = ['category', f'{metric}_AP'] * (num_columns // 2)
+                results_2d = itertools.zip_longest(*[
+                    results_flatten[i::num_columns]
+                    for i in range(num_columns)
+                ])
+
+                table = Table(title=classwise_table_title)
+                console = Console()
+                for name in headers:
+                    table.add_column(name, justify='left')
+                for _result in results_2d:
+                    table.add_row(*_result)
+                with console.capture() as capture:
+                    console.print(table, end='')
+                self.logger.info('\n' + capture.get())
diff --git a/requirements/optional.txt b/requirements/optional.txt
index e11968ee..18028250 100644
--- a/requirements/optional.txt
+++ b/requirements/optional.txt
@@ -1,2 +1,4 @@
+-e git+https://github.com/lvis-dataset/lvis-api.git#egg=lvis
+opencv-python!=4.5.5.62,!=4.5.5.64
 pycocotools
 shapely
diff --git a/tests/test_metrics/test_lvis_detection_metric.py b/tests/test_metrics/test_lvis_detection_metric.py
new file mode 100644
index 00000000..3c3a6052
--- /dev/null
+++ b/tests/test_metrics/test_lvis_detection_metric.py
@@ -0,0 +1,390 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import os.path as osp
+import pytest
+import tempfile
+from json import dump
+
+from mmeval.core.base_metric import BaseMetric
+from mmeval.metrics import LVISDetection
+from mmeval.utils import try_import
+
+coco_wrapper = try_import('mmeval.metrics.utils.coco_wrapper')
+
+
+def create_dummy_lvis_json(json_name):
+    dummy_mask = np.zeros((10, 10), order='F', dtype=np.uint8)
+    dummy_mask[:5, :5] = 1
+    rle_mask = coco_wrapper.mask_util.encode(dummy_mask)
+    rle_mask['counts'] = rle_mask['counts'].decode('utf-8')
+    image = {
+        'id': 0,
+        'width': 640,
+        'height': 640,
+        'neg_category_ids': [],
+        'not_exhaustive_category_ids': [],
+        'coco_url': 'http://images.cocodataset.org/val2017/0.jpg',
+    }
+
+    annotation_1 = {
+        'id': 1,
+        'image_id': 0,
+        'category_id': 1,
+        'area': 400,
+        'bbox': [50, 60, 20, 20],
+        'segmentation': rle_mask,
+    }
+
+    annotation_2 = {
+        'id': 2,
+        'image_id': 0,
+        'category_id': 1,
+        'area': 900,
+        'bbox': [100, 120, 30, 30],
+        'segmentation': rle_mask,
+    }
+
+    annotation_3 = {
+        'id': 3,
+        'image_id': 0,
+        'category_id': 2,
+        'area': 1600,
+        'bbox': [150, 160, 40, 40],
+        'segmentation': rle_mask,
+    }
+
+    annotation_4 = {
+        'id': 4,
+        'image_id': 0,
+        'category_id': 1,
+        'area': 10000,
+        'bbox': [250, 260, 100, 100],
+        'segmentation': rle_mask,
+    }
+
+    categories = [
+        {
+            'id': 1,
+            'name': 'aerosol_can',
+            'frequency': 'c',
+            'image_count': 64
+        },
+        {
+            'id': 2,
+            'name': 'air_conditioner',
+            'frequency': 'f',
+            'image_count': 364
+        },
+    ]
+
+    fake_json = {
+        'images': [image],
+        'annotations':
+        [annotation_1, annotation_2, annotation_3, annotation_4],
+        'categories': categories
+    }
+
+    with open(json_name, 'w') as f:
+        dump(fake_json, f)
+
+
+def _create_dummy_results():
+    # create fake results
+    bboxes = np.array([[50, 60, 70, 80], [100, 120, 130, 150],
+                       [150, 160, 190, 200], [250, 260, 350, 360]])
+    scores = np.array([1.0, 0.98, 0.96, 0.95])
+    labels = np.array([0, 0, 1, 0])
+
+    mask = np.zeros((10, 10), dtype=np.uint8)
+    mask[:5, :5] = 1
+
+    dummy_mask = [
+        coco_wrapper.mask_util.encode(
+            np.array(mask[:, :, np.newaxis], order='F', dtype='uint8'))[0]
+        for _ in range(4)
+    ]
+    return dict(
+        img_id=0,
+        bboxes=bboxes,
+        scores=scores,
+        labels=labels,
+        masks=dummy_mask)
+
+
+# TODO: move necessary function to somewhere
+def _gen_bboxes(num_bboxes, img_w=256, img_h=256):
+    # randomly generate bounding boxes in 'xyxy' format
+    x = np.random.rand(num_bboxes, ) * img_w
+    y = np.random.rand(num_bboxes, ) * img_h
+    w = np.random.rand(num_bboxes, ) * (img_w - x)
+    h = np.random.rand(num_bboxes, ) * (img_h - y)
+    return np.stack([x, y, x + w, y + h], axis=1)
+
+
+def _gen_masks(bboxes, img_w=256, img_h=256):
+    # randomly generate masks
+    if coco_wrapper is None:
+        raise ImportError('Please try to install official pycocotools by '
+                          '"pip install pycocotools"')
+    masks = []
+    for i, bbox in enumerate(bboxes):
+        mask = np.zeros((img_h, img_w))
+        bbox = bbox.astype(np.int32)
+        box_mask = (np.random.rand(bbox[3] - bbox[1], bbox[2] - bbox[0]) >
+                    0.3).astype(np.int32)
+        mask[bbox[1]:bbox[3], bbox[0]:bbox[2]] = box_mask
+        masks.append(
+            coco_wrapper.mask_util.encode(
+                np.array(mask[:, :, np.newaxis], order='F',
+                         dtype='uint8'))[0])  # encoded with RLE
+    return masks
+
+
+def _gen_prediction(num_pred=10,
+                    num_classes=2,
+                    img_w=256,
+                    img_h=256,
+                    img_id=0,
+                    with_mask=False):
+    # randomly create a prediction
+    pred_boxes = _gen_bboxes(num_bboxes=num_pred, img_w=img_w, img_h=img_h)
+    prediction = {
+        'img_id': img_id,
+        'bboxes': pred_boxes,
+        'scores': np.random.rand(num_pred, ),
+        'labels': np.random.randint(0, num_classes, size=(num_pred, ))
+    }
+    if with_mask:
+        pred_masks = _gen_masks(bboxes=pred_boxes, img_w=img_w, img_h=img_h)
+        prediction['masks'] = pred_masks
+    return prediction
+
+
+def _gen_groundtruth(num_gt=10,
+                     num_classes=2,
+                     img_w=256,
+                     img_h=256,
+                     img_id=0,
+                     with_mask=False):
+    # randomly create a ground truth
+    gt_boxes = _gen_bboxes(num_bboxes=num_gt, img_w=img_w, img_h=img_h)
+    groundtruth = {
+        'img_id': img_id,
+        'width': img_w,
+        'height': img_h,
+        'neg_category_ids': [],
+        'bboxes': gt_boxes,
+        'labels': np.random.randint(0, num_classes, size=(num_gt, )),
+        'ignore_flags': np.zeros(num_gt)
+    }
+    if with_mask:
+        gt_masks = _gen_masks(bboxes=gt_boxes, img_w=img_w, img_h=img_h)
+        groundtruth['masks'] = gt_masks
+    return groundtruth
+
+
+@pytest.mark.skipif(
+    coco_wrapper is None, reason='coco_wrapper is not available!')
+@pytest.mark.parametrize(
+    argnames='metric_kwargs',
+    argvalues=[
+        {},
+        {
+            'iou_thrs': [0.5, 0.75]
+        },
+        {
+            'classwise': True
+        },
+        {
+            'metric_items': ['AP', 'AP50']
+        },
+        {
+            'proposal_nums': 30
+        },
+    ])
+def test_box_metric_interface(metric_kwargs):
+    tmp_dir = tempfile.TemporaryDirectory()
+
+    # create dummy data
+    fake_json_file = osp.join(tmp_dir.name, 'fake_data.json')
+    create_dummy_lvis_json(fake_json_file)
+
+    num_classes = 2
+    metric = ['bbox']
+    # Avoid some potential errors
+    fake_dataset_metas = {
+        'classes': tuple([str(i) for i in range(num_classes)])
+    }
+    lvis_det_metric = LVISDetection(
+        ann_file=fake_json_file,
+        metric=metric,
+        dataset_meta=fake_dataset_metas,
+        **metric_kwargs)
+    assert isinstance(lvis_det_metric, BaseMetric)
+
+    metric_results = lvis_det_metric(predictions=[_create_dummy_results()])
+    assert isinstance(metric_results, dict)
+    assert 'bbox_AP' in metric_results
+    tmp_dir.cleanup()
+
+
+@pytest.mark.skipif(
+    coco_wrapper is None, reason='coco_wrapper is not available!')
+@pytest.mark.parametrize(
+    argnames='metric_kwargs',
+    argvalues=[
+        {},
+        {
+            'iou_thrs': [0.5, 0.75]
+        },
+        {
+            'classwise': True
+        },
+        {
+            'metric_items': ['AP', 'AP50']
+        },
+        {
+            'proposal_nums': 30
+        },
+    ])
+def test_segm_metric_interface(metric_kwargs):
+    tmp_dir = tempfile.TemporaryDirectory()
+
+    # create dummy data
+    fake_json_file = osp.join(tmp_dir.name, 'fake_data.json')
+    create_dummy_lvis_json(fake_json_file)
+
+    num_classes = 2
+    metric = ['segm']
+    # Avoid some potential errors
+    fake_dataset_metas = {
+        'classes': tuple([str(i) for i in range(num_classes)])
+    }
+    lvis_det_metric = LVISDetection(
+        ann_file=fake_json_file,
+        metric=metric,
+        dataset_meta=fake_dataset_metas,
+        **metric_kwargs)
+    assert isinstance(lvis_det_metric, BaseMetric)
+
+    metric_results = lvis_det_metric(predictions=[_create_dummy_results()])
+    assert isinstance(metric_results, dict)
+    assert 'segm_AP' in metric_results
+    tmp_dir.cleanup()
+
+
+@pytest.mark.skipif(
+    coco_wrapper is None, reason='coco_wrapper is not available!')
+def test_metric_invalid_usage():
+    tmp_dir = tempfile.TemporaryDirectory()
+
+    # create dummy data
+    fake_json_file = osp.join(tmp_dir.name, 'fake_data.json')
+    create_dummy_lvis_json(fake_json_file)
+
+    with pytest.raises(KeyError):
+        LVISDetection(ann_file=fake_json_file, metric='xxx')
+
+    with pytest.raises(TypeError):
+        LVISDetection(ann_file=fake_json_file, iou_thrs=1)
+
+    with pytest.raises(AssertionError):
+        LVISDetection(ann_file=fake_json_file, format_only=True)
+
+    num_classes = 2
+    # Avoid some potential errors
+    fake_dataset_metas = {
+        'classes': tuple([str(i) for i in range(num_classes)])
+    }
+    lvis_det_metric = LVISDetection(
+        ann_file=fake_json_file, dataset_meta=fake_dataset_metas)
+
+    with pytest.raises(KeyError):
+        prediction = _gen_prediction(num_classes=num_classes)
+        del prediction['bboxes']
+        lvis_det_metric([prediction])
+
+    with pytest.raises(AssertionError):
+        prediction = _gen_prediction(num_classes=num_classes)
+        lvis_det_metric(prediction)
+    tmp_dir.cleanup()
+
+
+@pytest.mark.skipif(
+    coco_wrapper is None, reason='coco_wrapper is not available!')
+def test_compute_metric():
+    tmp_dir = tempfile.TemporaryDirectory()
+
+    # create dummy data
+    fake_json_file = osp.join(tmp_dir.name, 'fake_data.json')
+    create_dummy_lvis_json(fake_json_file)
+    dummy_pred = _create_dummy_results()
+    fake_dataset_metas = dict(classes=['car', 'bicycle'])
+
+    # test single LVIS dataset evaluation
+    lvis_det_metric = LVISDetection(
+        ann_file=fake_json_file,
+        classwise=False,
+        outfile_prefix=f'{tmp_dir.name}/test',
+        dataset_meta=fake_dataset_metas)
+    eval_results = lvis_det_metric([dummy_pred])
+    target = {
+        'bbox_AP': 1.0,
+        'bbox_AP50': 1.0,
+        'bbox_AP75': 1.0,
+        'bbox_APc': 1.0,
+        'bbox_APf': 1.0,
+        'bbox_APr': -1.0,
+        'bbox_APs': 1.0,
+        'bbox_APm': 1.0,
+        'bbox_APl': 1.0,
+    }
+
+    results = {k: round(v, 4) for k, v in eval_results.items()}
+    assert results == target
+    assert osp.isfile(osp.join(tmp_dir.name, 'test.bbox.json'))
+
+    # test box and segm LVIS dataset evaluation
+    lvis_det_metric = LVISDetection(
+        ann_file=fake_json_file,
+        classwise=False,
+        metric=['bbox', 'segm'],
+        outfile_prefix=f'{tmp_dir.name}/test',
+        dataset_meta=fake_dataset_metas)
+    eval_results = lvis_det_metric([dummy_pred])
+    target = {
+        'bbox_AP': 1.0,
+        'bbox_AP50': 1.0,
+        'bbox_AP75': 1.0,
+        'bbox_APc': 1.0,
+        'bbox_APf': 1.0,
+        'bbox_APr': -1.0,
+        'bbox_APs': 1.0,
+        'bbox_APm': 1.0,
+        'bbox_APl': 1.0,
+        'segm_AP': 1.0,
+        'segm_AP50': 1.0,
+        'segm_AP75': 1.0,
+        'segm_APc': 1.0,
+        'segm_APf': 1.0,
+        'segm_APr': -1.0,
+        'segm_APs': 1.0,
+        'segm_APm': 1.0,
+        'segm_APl': 1.0,
+    }
+
+    results = {k: round(v, 4) for k, v in eval_results.items()}
+    assert results == target
+    assert osp.isfile(osp.join(tmp_dir.name, 'test.bbox.json'))
+    assert osp.isfile(osp.join(tmp_dir.name, 'test.segm.json'))
+
+    # test format only evaluation
+    lvis_det_metric = LVISDetection(
+        ann_file=fake_json_file,
+        classwise=False,
+        format_only=True,
+        outfile_prefix=f'{tmp_dir.name}/test',
+        dataset_meta=fake_dataset_metas)
+    eval_results = lvis_det_metric([dummy_pred])
+    assert osp.exists(f'{tmp_dir.name}/test.bbox.json')
+    assert eval_results == dict()
+    tmp_dir.cleanup()