diff --git a/mmeval/metrics/__init__.py b/mmeval/metrics/__init__.py
index 9ae21aef..252bfe13 100644
--- a/mmeval/metrics/__init__.py
+++ b/mmeval/metrics/__init__.py
@@ -5,6 +5,7 @@
 from .accuracy import Accuracy
 from .ava_map import AVAMeanAP
 from .bleu import BLEU
+from .cityscapes_detection import CityScapesDetection
 from .coco_detection import COCODetection
 from .connectivity_error import ConnectivityError
 from .dota_map import DOTAMeanAP
@@ -44,7 +45,7 @@
     'SumAbsoluteDifferences', 'GradientError', 'MattingMeanSquaredError',
     'ConnectivityError', 'ROUGE', 'Perplexity', 'KeypointEndPointError',
     'KeypointAUC', 'KeypointNME', 'NaturalImageQualityEvaluator',
-    'WordAccuracy'
+    'WordAccuracy', 'CityScapesDetection'
 ]
 
 _deprecated_msg = (
diff --git a/mmeval/metrics/cityscapes_detection.py b/mmeval/metrics/cityscapes_detection.py
new file mode 100644
index 00000000..6811c4d8
--- /dev/null
+++ b/mmeval/metrics/cityscapes_detection.py
@@ -0,0 +1,318 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import itertools
+import os
+import os.path as osp
+import tempfile
+import warnings
+from collections import OrderedDict
+from terminaltables import AsciiTable
+from typing import Dict, Optional, Sequence, Union
+
+from mmeval.core.base_metric import BaseMetric
+
+try:
+    import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval  # noqa: E501
+    import cityscapesscripts.helpers.labels as CSLabels
+
+    from .utils.cityscapes_wrapper import evaluateImgLists
+    HAS_CITYSCAPESAPI = True
+except ImportError:
+    HAS_CITYSCAPESAPI = False
+
+try:
+    from mmcv import imwrite
+except ImportError:
+    from mmeval.utils import imwrite
+
+
+class CityScapesDetection(BaseMetric):
+    """CityScapes metric for instance segmentation.
+
+    Args:
+        outfile_prefix (str): The prefix of txt and png files. It is the
+            saving path of the txt and png files, e.g. "a/b/prefix".
+            If not specified, a temp directory will be created. It should
+            be specified when ``format_only`` is True. Defaults to None.
+        seg_prefix (str, optional): Path to the directory which contains the
+            cityscapes instance segmentation masks. It is necessary for
+            training and validation. It could be None when inferring on the
+            test dataset. Defaults to None.
+        format_only (bool): Format the output results without performing
+            evaluation. It is useful when you want to format the result
+            to a specific format and submit it to the test server.
+            Defaults to False.
+        classwise (bool): Whether to return the computed results of each
+            class. Defaults to False.
+        dump_matches (bool): Whether to dump the matches.json file during
+            evaluation. Defaults to False.
+        backend_args (dict, optional): Arguments to instantiate the file
+            backend corresponding to the prefix of the uri. Defaults to None.
+        **kwargs: Keyword parameters passed to :class:`BaseMetric`.
+
+    Examples:
+        >>> import numpy as np
+        >>> import os
+        >>> import os.path as osp
+        >>> import tempfile
+        >>> from PIL import Image
+        >>> from mmeval import CityScapesDetection
+        >>>
+        >>> tmp_dir = tempfile.TemporaryDirectory()
+        >>> dataset_metas = {
+        ...     'classes': ('person', 'rider', 'car', 'truck', 'bus', 'train',
+        ...                 'motorcycle', 'bicycle')
+        ... }
+        >>> seg_prefix = osp.join(tmp_dir.name, 'cityscapes', 'gtFine', 'val')
+        >>> os.makedirs(seg_prefix, exist_ok=True)
+        >>> cityscapes_det_metric = CityScapesDetection(
+        ...     dataset_meta=dataset_metas,
+        ...     seg_prefix=seg_prefix,
+        ...     classwise=True)
+        >>>
+        >>> def _gen_fake_datasamples(seg_prefix):
+        ...     city = 'lindau'
+        ...     os.makedirs(osp.join(seg_prefix, city), exist_ok=True)
+        ...
+        ...     sequenceNb = '000000'
+        ...     frameNb1 = '000019'
+        ...     img_name1 = f'{city}_{sequenceNb}_{frameNb1}_gtFine_instanceIds.png'
+        ...     img_path1 = osp.join(seg_prefix, city, img_name1)
+        ...     basename1 = osp.splitext(osp.basename(img_path1))[0]
+        ...     masks1 = np.zeros((20, 20), dtype=np.int32)
+        ...     masks1[:10, :10] = 24 * 1000
+        ...     Image.fromarray(masks1).save(img_path1)
+        ...
+        ...     dummy_mask1 = np.zeros((1, 20, 20), dtype=np.uint8)
+        ...     dummy_mask1[:, :10, :10] = 1
+        ...     prediction = {
+        ...         'basename': basename1,
+        ...         'mask_scores': np.array([1.0]),
+        ...         'labels': np.array([0]),
+        ...         'masks': dummy_mask1
+        ...     }
+        ...     groundtruth = {
+        ...         'file_name': img_path1
+        ...     }
+        ...
+        ...     return [prediction], [groundtruth]
+        >>>
+        >>> predictions, groundtruths = _gen_fake_datasamples(seg_prefix)
+        >>> cityscapes_det_metric(predictions, groundtruths)  # doctest: +ELLIPSIS  # noqa: E501
+        {'mAP': ..., 'AP50': ...}
+        >>> tmp_dir.cleanup()
+    """
+
+    def __init__(self,
+                 outfile_prefix: Optional[str] = None,
+                 seg_prefix: Optional[str] = None,
+                 format_only: bool = False,
+                 classwise: bool = False,
+                 dump_matches: bool = False,
+                 backend_args: Optional[dict] = None,
+                 **kwargs):
+
+        if not HAS_CITYSCAPESAPI:
+            raise RuntimeError('Failed to import `cityscapesscripts`. '
+                               'Please try to install the official '
+                               'cityscapesscripts with '
+                               '"pip install cityscapesscripts"')
+        super().__init__(**kwargs)
+
+        self.tmp_dir = None
+        self.format_only = format_only
+        if self.format_only:
+            assert outfile_prefix is not None, \
+                '`outfile_prefix` must not be None when `format_only` is ' \
+                'True, otherwise the result files will be saved to a temp ' \
+                'directory which will be cleaned up at the end.'
+        else:
+            assert seg_prefix is not None, \
+                '`seg_prefix` is necessary when computing the CityScapes ' \
+                'metrics'
+
+        # outfile_prefix should be a prefix of a path which points to a shared
+        # storage when training or testing with multiple nodes.
+        self.outfile_prefix = outfile_prefix
+        if outfile_prefix is None:
+            self.tmp_dir = tempfile.TemporaryDirectory()
+            self.outfile_prefix = osp.join(self.tmp_dir.name, 'results')
+        else:
+            # the directory to save the predicted instance segmentation masks
+            self.outfile_prefix = osp.join(self.outfile_prefix, 'results')  # type: ignore # yapf: disable # noqa: E501
+        # make dir to avoid potential error
+        dir_name = osp.expanduser(self.outfile_prefix)
+        os.makedirs(dir_name, exist_ok=True)
+
+        self.seg_prefix = seg_prefix
+        self.classwise = classwise
+        self.backend_args = backend_args
+        self.dump_matches = dump_matches
+
+    def add(self, predictions: Sequence[Dict], groundtruths: Sequence[Dict]) -> None:  # type: ignore # yapf: disable # noqa: E501
+        """Add the intermediate results to ``self._results``.
+
+        Args:
+            predictions (Sequence[dict]): A sequence of dicts. Each dict
+                represents a detection result for an image, with the
+                following keys:
+
+                - basename (str): The image name.
+                - masks (numpy.ndarray): Shape (N, H, W), the predicted masks.
+                - labels (numpy.ndarray): Shape (N, ), the predicted labels
+                  of bounding boxes.
+                - mask_scores (np.array, optional): Shape (N, ), the predicted
+                  scores of masks.
+
+            groundtruths (Sequence[dict]): A sequence of dicts. If loaded from
+                an `ann_file`, the dicts inside can be empty. Otherwise, each
+                dict represents the groundtruth for an image, with the
+                following key:
+
+                - file_name (str): The absolute path of the groundtruth image.
+        """
+        for prediction, groundtruth in zip(predictions, groundtruths):
+            assert isinstance(prediction, dict), 'The prediction should be ' \
+                f'a sequence of dict, but got a sequence of {type(prediction)}.'  # noqa: E501
+            assert isinstance(groundtruth, dict), 'The groundtruth should be ' \
+                f'a sequence of dict, but got a sequence of {type(groundtruth)}.'  # noqa: E501
+            prediction = self._process_prediction(prediction)
+            self._results.append((prediction, groundtruth))
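Note: besides the one-shot `__call__` shown in the class docstring, results can be accumulated batch by batch with `add` and aggregated once at the end. A minimal sketch, assuming `dataset_metas`, `seg_prefix` and per-image dicts in the formats documented above:

```python
# Illustrative incremental usage (not part of the patch): `add` converts the
# masks of each batch to txt/png files immediately; `compute` then gathers
# the intermediate results across ranks and runs the cityscapesscripts
# evaluation once.
metric = CityScapesDetection(dataset_meta=dataset_metas, seg_prefix=seg_prefix)
for pred, gt in zip(predictions, groundtruths):
    metric.add([pred], [gt])
eval_results = metric.compute()  # e.g. {'mAP': ..., 'AP50': ...}
```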
+ """ + for prediction, groundtruth in zip(predictions, groundtruths): + assert isinstance(prediction, dict), 'The prediciton should be ' \ + f'a sequence of dict, but got a sequence of {type(prediction)}.' # noqa: E501 + assert isinstance(groundtruth, dict), 'The label should be ' \ + f'a sequence of dict, but got a sequence of {type(groundtruth)}.' # noqa: E501 + prediction = self._process_prediction(prediction) + self._results.append((prediction, groundtruth)) + + def compute_metric(self, results: list) -> Dict[str, Union[float, list]]: + """Compute the CityScapes metrics. + + Args: + results (List[tuple]): A list of tuple. Each tuple is the + prediction and ground truth of an image. This list has already + been synced across all ranks. + + Returns: + dict: The computed metric. The keys are the names of + the metrics, and the values are corresponding results. + """ + eval_results: OrderedDict = OrderedDict() + table_results: OrderedDict = OrderedDict() + if self.format_only: + self.logger.info('Results are saved in ' # type: ignore + f'{osp.dirname(self.outfile_prefix)}') # type: ignore # yapf: disable # noqa: E501 + return eval_results + + gt_instances_file = osp.join(self.outfile_prefix, 'gtInstances.json') # type: ignore # yapf: disable # noqa: E501 + # split gt and prediction list + preds, gts = zip(*results) + CSEval.args.JSONOutput = False + CSEval.args.colorized = False + CSEval.args.gtInstancesFile = gt_instances_file + + groundtruth_list = [gt['file_name'] for gt in gts] + prediction_list = [pred['pred_txt'] for pred in preds] + CSEval_results = evaluateImgLists( + prediction_list, + groundtruth_list, + CSEval.args, + self.backend_args, + dump_matches=self.dump_matches)['averages'] + + map = float(CSEval_results['allAp']) + map_50 = float(CSEval_results['allAp50%']) + + eval_results['mAP'] = map + eval_results['AP50'] = map_50 + + results_list = ('mAP', f'{round(map * 100, 2):0.2f}', + f'{round(map_50 * 100, 2):0.2f}') + + if self.classwise: + results_per_category = [] + for category, aps in CSEval_results['classes'].items(): + eval_results[f'{category}_ap'] = float(aps['ap']) + eval_results[f'{category}_ap50'] = float(aps['ap50%']) + results_per_category.append( + (f'{category}', f'{round(float(aps["ap"]) * 100, 2):0.2f}', + f'{round(float(aps["ap50%"]) * 100, 2):0.2f}')) + results_per_category.append(results_list) + table_results['results_list'] = results_per_category + else: + table_results['results_list'] = [results_list] + + self._print_results(table_results) + return eval_results + + def __del__(self) -> None: + """Clean up the results if necessary.""" + if self.tmp_dir is not None: + self.tmp_dir.cleanup() + + def _process_prediction(self, prediction: dict) -> dict: + """Process prediction. + + Args: + prediction (Sequence[dict]): A sequence of dict. Each dict + representing a detection result for an image, with the + following keys: + + - basename (str): The image name. + - masks (numpy.ndarray): Shape (N, H, W), the predicted masks. + - labels (numpy.ndarray): Shape (N, ), the predicted labels + of bounding boxes. + - mask_scores (np.array, optional): Shape (N, ), the predicted + scores of masks. + + Returns: + dict: The processed prediction results. With key `pred_txt`. 
+ """ + classes = self.classes + pred = dict() + basename = prediction['basename'] + + pred_txt = osp.join(self.outfile_prefix, basename + '_pred.txt') # type: ignore # yapf: disable # noqa: E501 + pred['pred_txt'] = pred_txt + + labels = prediction['labels'] + masks = prediction['masks'] + mask_scores = prediction['mask_scores'] + with open(pred_txt, 'w') as f: + for i, (label, mask, + mask_score) in enumerate(zip(labels, masks, mask_scores)): + class_name = classes[label] + class_id = CSLabels.name2label[class_name].id + png_filename = osp.join( + self.outfile_prefix, basename + f'_{i}_{class_name}.png') # type: ignore # yapf: disable # noqa: E501 + imwrite(mask, png_filename) + f.write(f'{osp.basename(png_filename)} ' + f'{class_id} {mask_score}\n') + return pred + + @property + def classes(self) -> tuple: + """Get classes from self.dataset_meta.""" + if self.dataset_meta and 'classes' in self.dataset_meta: + classes = self.dataset_meta['classes'] + elif self.dataset_meta and 'CLASSES' in self.dataset_meta: + classes = self.dataset_meta['CLASSES'] + warnings.warn( + 'DeprecationWarning: The `CLASSES` in `dataset_meta` is ' + 'deprecated, use `classes` instead!') + else: + raise RuntimeError('Could not find `classes` in dataset_meta: ' + f'{self.dataset_meta}') + return classes + + def _print_results(self, table_results: dict) -> None: + """Print the evaluation results table. + + Args: + table_results (dict): The computed metric. + """ + result = table_results['results_list'] + + header = ['class', 'AP(%)', 'AP50(%)'] + table_title = ' Cityscapes Results' + + results_flatten = list(itertools.chain(*result)) + + results_2d = itertools.zip_longest( + *[results_flatten[i::3] for i in range(3)]) + table_data = [header] + table_data += [result for result in results_2d] + table = AsciiTable(table_data, title=table_title) + table.inner_footing_row_border = True + self.logger.info(f'CityScapes Evaluation Results: \n {table.table}') diff --git a/mmeval/metrics/utils/cityscapes_wrapper.py b/mmeval/metrics/utils/cityscapes_wrapper.py new file mode 100644 index 00000000..7dd22a85 --- /dev/null +++ b/mmeval/metrics/utils/cityscapes_wrapper.py @@ -0,0 +1,281 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright (c) https://github.com/mcordts/cityscapesScripts +# A wrapper of `cityscapesscripts` which supports loading groundtruth +# image from `backend_args`. +import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval # noqa: E501 +import io +import json +import numpy as np +import os +import sys +from cityscapesscripts.evaluation.instance import Instance +from cityscapesscripts.helpers.csHelpers import id2label # noqa: E501 +from cityscapesscripts.helpers.csHelpers import labels, writeDict2JSON +from pathlib import Path +from PIL import Image +from typing import Optional, Union + +from mmeval.fileio import get + + +def evaluateImgLists(prediction_list: list, + groundtruth_list: list, + args: CSEval.CArgs, + backend_args: Optional[dict] = None, + dump_matches: bool = False) -> dict: + """A wrapper of obj:``cityscapesscripts.evaluation. + evalInstanceLevelSemanticLabeling.evaluateImgLists``. Support loading + groundtruth image from file backend. + + Args: + prediction_list (list): A list of prediction txt file. + groundtruth_list (list): A list of groundtruth image file. + args (CSEval.CArgs): A global object setting in + obj:``cityscapesscripts.evaluation. 
diff --git a/mmeval/metrics/utils/cityscapes_wrapper.py b/mmeval/metrics/utils/cityscapes_wrapper.py
new file mode 100644
index 00000000..7dd22a85
--- /dev/null
+++ b/mmeval/metrics/utils/cityscapes_wrapper.py
@@ -0,0 +1,281 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright (c) https://github.com/mcordts/cityscapesScripts
+# A wrapper of `cityscapesscripts` which supports loading the groundtruth
+# image with `backend_args`.
+import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval  # noqa: E501
+import io
+import json
+import numpy as np
+import os
+import sys
+from cityscapesscripts.evaluation.instance import Instance
+from cityscapesscripts.helpers.csHelpers import id2label  # noqa: E501
+from cityscapesscripts.helpers.csHelpers import labels, writeDict2JSON
+from pathlib import Path
+from PIL import Image
+from typing import Optional, Union
+
+from mmeval.fileio import get
+
+
+def evaluateImgLists(prediction_list: list,
+                     groundtruth_list: list,
+                     args: CSEval.CArgs,
+                     backend_args: Optional[dict] = None,
+                     dump_matches: bool = False) -> dict:
+    """A wrapper of :obj:`cityscapesscripts.evaluation.
+    evalInstanceLevelSemanticLabeling.evaluateImgLists`. Supports loading the
+    groundtruth image from a file backend.
+
+    Args:
+        prediction_list (list): A list of prediction txt files.
+        groundtruth_list (list): A list of groundtruth image files.
+        args (CSEval.CArgs): A global object setting in
+            :obj:`cityscapesscripts.evaluation.
+            evalInstanceLevelSemanticLabeling`.
+        backend_args (dict, optional): Arguments to instantiate the file
+            backend corresponding to the prefix of the uri. Defaults to None.
+        dump_matches (bool): Whether to dump matches.json. Defaults to False.
+
+    Returns:
+        dict: The computed metric.
+    """
+    # determine labels of interest
+    CSEval.setInstanceLabels(args)
+    # get dictionary of all ground truth instances
+    gt_instances = getGtInstances(
+        groundtruth_list, args, backend_args=backend_args)
+    # match predictions and ground truth
+    matches = matchGtWithPreds(prediction_list, groundtruth_list, gt_instances,
+                               args, backend_args)
+    if dump_matches:
+        CSEval.writeDict2JSON(matches, 'matches.json')
+    # evaluate matches
+    apScores = CSEval.evaluateMatches(matches, args)
+    # averages
+    avgDict = CSEval.computeAverages(apScores, args)
+    # result dict
+    resDict = CSEval.prepareJSONDataForResults(avgDict, apScores, args)
+    if args.JSONOutput:
+        # create output folder if necessary
+        path = os.path.dirname(args.exportFile)
+        CSEval.ensurePath(path)
+        # Write APs to JSON
+        CSEval.writeDict2JSON(resDict, args.exportFile)
+
+    CSEval.printResults(avgDict, args)
+
+    return resDict
+
+
+def matchGtWithPreds(prediction_list: list,
+                     groundtruth_list: list,
+                     gt_instances: dict,
+                     args: CSEval.CArgs,
+                     backend_args=None):
+    """A wrapper of :obj:`cityscapesscripts.evaluation.
+    evalInstanceLevelSemanticLabeling.matchGtWithPreds`. Supports loading the
+    groundtruth image from a file backend.
+
+    Args:
+        prediction_list (list): A list of prediction txt files.
+        groundtruth_list (list): A list of groundtruth image files.
+        gt_instances (dict): Groundtruth dict.
+        args (CSEval.CArgs): A global object setting in
+            :obj:`cityscapesscripts.evaluation.
+            evalInstanceLevelSemanticLabeling`.
+        backend_args (dict, optional): Arguments to instantiate the file
+            backend corresponding to the prefix of the uri. Defaults to None.
+
+    Returns:
+        dict: The processed prediction and groundtruth result.
+    """
+    matches: dict = dict()
+    if not args.quiet:
+        print(f'Matching {len(prediction_list)} pairs of images...')
+
+    count = 0
+    for (pred, gt) in zip(prediction_list, groundtruth_list):
+        # Read input files
+        gt_image = readGTImage(gt, backend_args)
+        pred_info = readPredInfo(pred)
+
+        # Get and filter ground truth instances
+        unfiltered_instances = gt_instances[gt]
+        cur_gt_instances_orig = CSEval.filterGtInstances(
+            unfiltered_instances, args)
+
+        # Try to assign all predictions
+        (cur_gt_instances,
+         cur_pred_instances) = CSEval.assignGt2Preds(cur_gt_instances_orig,
+                                                     gt_image, pred_info, args)
+
+        # append to global dict
+        matches[gt] = {}
+        matches[gt]['groundTruth'] = cur_gt_instances
+        matches[gt]['prediction'] = cur_pred_instances
+
+        count += 1
+        if not args.quiet:
+            print(f'\rImages Processed: {count}', end=' ')
+            sys.stdout.flush()
+
+    if not args.quiet:
+        print('')
+
+    return matches
+
+
+def readGTImage(image_file: Union[str, Path],
+                backend_args: Optional[dict] = None) -> np.ndarray:
+    """Read an image from path. Same as :obj:`cityscapesscripts.evaluation.
+    evalInstanceLevelSemanticLabeling.readGTImage`, but supports loading the
+    groundtruth image from a file backend.
+
+    Args:
+        image_file (str or Path): Either a str or pathlib.Path.
+        backend_args (dict, optional): Instantiates the corresponding file
+            backend. It may contain `backend` key to specify the file
+            backend. If it contains, the file backend corresponding to this
+            value will be used and initialized with the remaining values,
+            otherwise the corresponding file backend will be selected
+            based on the prefix of the file path. Defaults to None.
+
+    Returns:
+        np.ndarray: The groundtruth image.
+    """
+    img_bytes = get(image_file, backend_args=backend_args)
+    with io.BytesIO(img_bytes) as buff:
+        img = Image.open(buff)
+        img = np.array(img)
+    return img
+
+
+def readPredInfo(prediction_file: str) -> dict:
+    """A wrapper of :obj:`cityscapesscripts.evaluation.
+    evalInstanceLevelSemanticLabeling.readPredInfo`.
+
+    Args:
+        prediction_file (str): The prediction txt file.
+
+    Returns:
+        dict: The processed prediction results.
+    """
+
+    printError = CSEval.printError
+
+    predInfo = {}
+    if (not os.path.isfile(prediction_file)):
+        printError(f"Infofile '{prediction_file}' "
+                   'for the predictions not found.')
+    with open(prediction_file) as f:
+        for line in f:
+            splittedLine = line.split(' ')
+            if len(splittedLine) != 3:
+                printError('Invalid prediction file. Expected content: '
+                           'relPathPrediction1 labelIDPrediction1 '
+                           'confidencePrediction1')
+            if os.path.isabs(splittedLine[0]):
+                printError('Invalid prediction file. First entry in each '
+                           'line must be a relative path.')
+
+            filename = os.path.join(
+                os.path.dirname(prediction_file), splittedLine[0])
+
+            imageInfo = {}
+            imageInfo['labelID'] = int(float(splittedLine[1]))
+            imageInfo['conf'] = float(splittedLine[2])  # type: ignore
+            predInfo[filename] = imageInfo
+
+    return predInfo
+
+
+def getGtInstances(groundtruth_list: list,
+                   args: CSEval.CArgs,
+                   backend_args: Optional[dict] = None) -> dict:
+    """A wrapper of :obj:`cityscapesscripts.evaluation.
+    evalInstanceLevelSemanticLabeling.getGtInstances`. Supports loading the
+    groundtruth image from a file backend.
+
+    Args:
+        groundtruth_list (list): A list of groundtruth image files.
+        args (CSEval.CArgs): A global object setting in
+            :obj:`cityscapesscripts.evaluation.
+            evalInstanceLevelSemanticLabeling`.
+        backend_args (dict, optional): Arguments to instantiate the file
+            backend corresponding to the prefix of the uri. Defaults to None.
+
+    Returns:
+        dict: The groundtruth instances.
+    """
+    # if there is a global statistics json, then load it
+    if (os.path.isfile(args.gtInstancesFile)):
+        if not args.quiet:
+            print('Loading ground truth instances from JSON.')
+        with open(args.gtInstancesFile) as json_file:
+            gt_instances = json.load(json_file)
+    # otherwise create it
+    else:
+        if (not args.quiet):
+            print('Creating ground truth instances from png files.')
+        gt_instances = instances2dict(
+            groundtruth_list, args, backend_args=backend_args)
+        writeDict2JSON(gt_instances, args.gtInstancesFile)
+
+    return gt_instances
+
+
+def instances2dict(image_list: list,
+                   args: CSEval.CArgs,
+                   backend_args: Optional[dict] = None) -> dict:
+    """A wrapper of :obj:`cityscapesscripts.evaluation.
+    evalInstanceLevelSemanticLabeling.instances2dict`. Supports loading the
+    groundtruth image from a file backend.
+
+    Args:
+        image_list (list): A list of image files.
+        args (CSEval.CArgs): A global object setting in
+            :obj:`cityscapesscripts.evaluation.
+            evalInstanceLevelSemanticLabeling`.
+        backend_args (dict, optional): Arguments to instantiate the file
+            backend corresponding to the prefix of the uri. Defaults to None.
+
+    Returns:
+        dict: The processed groundtruth results.
+ """ + imgCount = 0 + instanceDict = {} + + if not isinstance(image_list, list): + image_list = [image_list] + + if not args.quiet: + print(f'Processing {len(image_list)} images...') + + for image_name in image_list: + # Load image + img_bytes = get(image_name, backend_args=backend_args) + with io.BytesIO(img_bytes) as buff: + img = Image.open(buff) + imgNp = np.array(img) + + # Initialize label categories + instances: dict = {} + for label in labels: + instances[label.name] = [] + + # Loop through all instance ids in instance image + for instanceId in np.unique(imgNp): + instanceObj = Instance(imgNp, instanceId) + + instances[id2label[instanceObj.labelID].name].append( + instanceObj.toDict()) + + instanceDict[image_name] = instances + imgCount += 1 + + if not args.quiet: + print(f'\rImages Processed: {imgCount}', end=' ') + sys.stdout.flush() + + return instanceDict diff --git a/mmeval/utils/__init__.py b/mmeval/utils/__init__.py index fc999bac..97fb4bc1 100644 --- a/mmeval/utils/__init__.py +++ b/mmeval/utils/__init__.py @@ -1,10 +1,10 @@ # Copyright (c) OpenMMLab. All rights reserved. - +from .image_io import imread, imwrite from .logging import DEFAULT_LOGGER from .misc import has_method, is_list_of, is_seq_of, is_tuple_of, try_import from .path import is_filepath __all__ = [ 'try_import', 'has_method', 'is_seq_of', 'is_list_of', 'is_tuple_of', - 'is_filepath', 'DEFAULT_LOGGER' + 'is_filepath', 'DEFAULT_LOGGER', 'imread', 'imwrite' ] diff --git a/mmeval/utils/image_io.py b/mmeval/utils/image_io.py new file mode 100644 index 00000000..8c22fde1 --- /dev/null +++ b/mmeval/utils/image_io.py @@ -0,0 +1,99 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +import os +import os.path as osp +from pathlib import Path +from typing import TYPE_CHECKING, Optional, Union + +from mmeval.utils.misc import try_import +from mmeval.utils.path import is_filepath + +if TYPE_CHECKING: + import cv2 +else: + cv2 = try_import('cv2') + + +def imwrite(img: np.ndarray, file_path: str) -> bool: + """Write image to local path. + + Args: + img (np.ndarray): Image array to be written. + file_path (str): Image file path. + + Returns: + bool: Successful or not. + """ + if cv2 is None: + raise ImportError('To use `imwrite` function, ' + 'please install opencv-python first.') + + assert is_filepath(file_path) + + # auto make dir + dir_name = osp.expanduser(osp.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + + img_ext = osp.splitext(file_path)[-1] + file_path = str(file_path) + # Encode image according to image suffix. + # For example, if image path is '/path/your/img.jpg', the encode + # format is '.jpg'. + flag, img_buff = cv2.imencode(img_ext, img) + with open(file_path, 'wb') as f: + f.write(img_buff) + return flag + + +def imread(file_path: Union[str, Path], + flag: str = 'color', + channel_order: str = 'rgb', + backend_args: Optional[dict] = None) -> np.ndarray: + """Read an image from path. + + Args: + file_path (str or Path): Either a str or pathlib.Path. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale`, `unchanged`, + `color_ignore_orientation` and `grayscale_ignore_orientation`. + Defaults to 'color'. + channel_order (str): Order of channel, candidates are `bgr` and `rgb`. + Defaults to 'rgb'. + backend_args (dict, optional): Instantiates the corresponding file + backend. It may contain `backend` key to specify the file + backend. 
+            backend. If it contains, the file backend corresponding to this
+            value will be used and initialized with the remaining values,
+            otherwise the corresponding file backend will be selected
+            based on the prefix of the file path. Defaults to None.
+
+    Returns:
+        ndarray: Loaded image array.
+    """
+    if cv2 is None:
+        raise ImportError('To use the `imread` function, '
+                          'please install opencv-python first.')
+
+    imread_flags = {
+        'color': cv2.IMREAD_COLOR,
+        'grayscale': cv2.IMREAD_GRAYSCALE,
+        'unchanged': cv2.IMREAD_UNCHANGED,
+        'color_ignore_orientation':
+        cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_COLOR,
+        'grayscale_ignore_orientation':
+        cv2.IMREAD_IGNORE_ORIENTATION | cv2.IMREAD_GRAYSCALE
+    }
+
+    from mmeval.fileio import get
+
+    assert is_filepath(file_path)
+
+    img_bytes = get(file_path, backend_args=backend_args)
+    img_np = np.frombuffer(img_bytes, np.uint8)
+    flag = imread_flags[flag] if isinstance(flag, str) else flag
+    img = cv2.imdecode(img_np, flag)
+    if flag == cv2.IMREAD_COLOR and channel_order == 'rgb':
+        cv2.cvtColor(img, cv2.COLOR_BGR2RGB, img)
+    return img
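A quick sanity check for the new helpers; a minimal sketch assuming opencv-python is installed (the file name is illustrative):

```python
import os.path as osp
import tempfile

import numpy as np

from mmeval.utils import imread, imwrite

with tempfile.TemporaryDirectory() as tmp_dir:
    img = np.random.randint(0, 256, (20, 20, 3), dtype=np.uint8)
    path = osp.join(tmp_dir, 'demo.png')
    assert imwrite(img, path)  # parent directories are created automatically
    loaded = imread(path)  # decoded and returned in RGB order by default
    assert loaded.shape == (20, 20, 3)
```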
diff --git a/requirements/optional.txt b/requirements/optional.txt
index 1bd0bbd5..503aa7f6 100644
--- a/requirements/optional.txt
+++ b/requirements/optional.txt
@@ -1,4 +1,6 @@
+cityscapesscripts
 opencv-python!=4.5.5.62,!=4.5.5.64
 pycocotools
 scipy
 shapely
+terminaltables
diff --git a/tests/test_metrics/test_cityscapes_detection.py b/tests/test_metrics/test_cityscapes_detection.py
new file mode 100644
index 00000000..7c2b4c47
--- /dev/null
+++ b/tests/test_metrics/test_cityscapes_detection.py
@@ -0,0 +1,130 @@
+import math
+import numpy as np
+import os
+import os.path as osp
+import pytest
+import tempfile
+from PIL import Image
+
+from mmeval.metrics import CityScapesDetection
+from mmeval.utils import try_import
+
+cityscapesscripts = try_import('cityscapesscripts')
+
+
+def _gen_fake_datasamples(seg_prefix):
+    city = 'lindau'
+    os.makedirs(osp.join(seg_prefix, city), exist_ok=True)
+
+    sequenceNb = '000000'
+    frameNb1 = '000019'
+    img_name1 = f'{city}_{sequenceNb}_{frameNb1}_gtFine_instanceIds.png'
+    img_path1 = osp.join(seg_prefix, city, img_name1)
+    basename1 = osp.splitext(osp.basename(img_path1))[0]
+    masks1 = np.zeros((20, 20), dtype=np.int32)
+    masks1[:10, :10] = 24 * 1000
+    Image.fromarray(masks1).save(img_path1)
+
+    dummy_mask1 = np.zeros((1, 20, 20), dtype=np.uint8)
+    dummy_mask1[:, :10, :10] = 1
+    prediction1 = {
+        'basename': basename1,
+        'mask_scores': np.array([1.0]),
+        'labels': np.array([0]),
+        'masks': dummy_mask1
+    }
+    groundtruth1 = {'file_name': img_path1}
+
+    frameNb2 = '000020'
+    img_name2 = f'{city}_{sequenceNb}_{frameNb2}_gtFine_instanceIds.png'
+    img_path2 = osp.join(seg_prefix, city, img_name2)
+    basename2 = osp.splitext(osp.basename(img_path2))[0]
+    masks2 = np.zeros((20, 20), dtype=np.int32)
+    masks2[:10, :10] = 24 * 1000 + 1
+    Image.fromarray(masks2).save(img_path2)
+
+    dummy_mask2 = np.zeros((1, 20, 20), dtype=np.uint8)
+    dummy_mask2[:, :10, :10] = 1
+    prediction2 = {
+        'basename': basename2,
+        'mask_scores': np.array([0.98]),
+        'labels': np.array([1]),
+        'masks': dummy_mask2
+    }
+    groundtruth2 = {'file_name': img_path2}
+    return [prediction1, prediction2], [groundtruth1, groundtruth2]
+
+
+@pytest.mark.skipif(
+    cityscapesscripts is None, reason='cityscapesscripts is not available!')
+def test_metric_invalid_usage():
+    with pytest.raises(AssertionError):
+        CityScapesDetection(
+            outfile_prefix='tmp/cityscapes/results', seg_prefix=None)
+
+    with pytest.raises(AssertionError):
+        CityScapesDetection(outfile_prefix=None, format_only=True)
+
+
+@pytest.mark.skipif(
+    cityscapesscripts is None, reason='cityscapesscripts is not available!')
+def test_compute_metric():
+    tmp_dir = tempfile.TemporaryDirectory()
+
+    dataset_metas = {
+        'classes': ('person', 'rider', 'car', 'truck', 'bus', 'train',
+                    'motorcycle', 'bicycle')
+    }
+
+    # create dummy data
+    seg_prefix = osp.join(tmp_dir.name, 'cityscapes', 'gtFine', 'val')
+    os.makedirs(seg_prefix, exist_ok=True)
+    predictions, groundtruths = _gen_fake_datasamples(seg_prefix)
+
+    # test single cityscapes dataset evaluation
+    cityscapes_det_metric = CityScapesDetection(
+        dataset_meta=dataset_metas,
+        outfile_prefix=osp.join(tmp_dir.name, 'test'),
+        seg_prefix=seg_prefix,
+        classwise=False)
+
+    eval_results = cityscapes_det_metric(
+        predictions=predictions, groundtruths=groundtruths)
+    target = {'mAP': 0.5, 'AP50': 0.5}
+    assert eval_results == target
+
+    # test classwise result evaluation
+    cityscapes_det_metric = CityScapesDetection(
+        dataset_meta=dataset_metas,
+        outfile_prefix=osp.join(tmp_dir.name, 'test'),
+        seg_prefix=seg_prefix,
+        classwise=True)
+
+    eval_results = cityscapes_det_metric(
+        predictions=predictions, groundtruths=groundtruths)
+    mAP = eval_results.pop('mAP')
+    AP50 = eval_results.pop('AP50')
+    person_ap = eval_results.pop('person_ap')
+    person_ap50 = eval_results.pop('person_ap50')
+    assert mAP == 0.5
+    assert AP50 == 0.5
+    assert person_ap == 0.5
+    assert person_ap50 == 0.5
+    # other classes' ap and ap50 should be nan
+    for v in eval_results.values():
+        assert math.isnan(v)
+
+    # test format only evaluation
+    cityscapes_det_metric = CityScapesDetection(
+        dataset_meta=dataset_metas,
+        format_only=True,
+        outfile_prefix=osp.join(tmp_dir.name, 'test'),
+        seg_prefix=seg_prefix,
+        classwise=True)
+
+    eval_results = cityscapes_det_metric(
+        predictions=predictions, groundtruths=groundtruths)
+    assert osp.exists(f'{osp.join(tmp_dir.name, "test")}')
+    assert eval_results == dict()
+
+    tmp_dir.cleanup()
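For reviewers, a minimal sketch of the submission workflow that `format_only` enables; the output path and the prediction dict are illustrative, not part of the patch:

```python
import numpy as np

from mmeval import CityScapesDetection

metric = CityScapesDetection(
    dataset_meta={'classes': ('person', 'rider', 'car', 'truck', 'bus',
                              'train', 'motorcycle', 'bicycle')},
    format_only=True,  # only dump the result files, skip evaluation
    outfile_prefix='work_dirs/cityscapes_submit')

pred = {
    'basename': 'lindau_000000_000019_gtFine_instanceIds',
    'labels': np.array([0]),
    'mask_scores': np.array([0.9]),
    'masks': np.zeros((1, 20, 20), dtype=np.uint8),
}
# groundtruth dicts may be empty in format-only mode
assert metric([pred], [{}]) == {}
# work_dirs/cityscapes_submit/results now holds the *_pred.txt and mask png
# files expected by the Cityscapes test server.
```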