[WIP] Add Panoptic Quality (PQ) #408

Draft · wants to merge 43 commits into base: main

Commits (43):
e2b0adb - First draft (Dec 24, 2022)
cc29a21 - More improvements (Dec 24, 2022)
65136a8 - Update features (Dec 24, 2022)
970653d - Remove feature (Dec 24, 2022)
67123ca - Add print statements (Dec 24, 2022)
9c97456 - Use Image feature (Dec 24, 2022)
4d2fdc7 - Update feature (Dec 24, 2022)
3d43d5b - Fix code (Dec 24, 2022)
3e6a0db - More fixes (Dec 24, 2022)
12fe7ee - Fix more code (Dec 24, 2022)
439c2e3 - Add print statement (Dec 24, 2022)
abdfb3e - Add print statement (Dec 24, 2022)
d683c2f - Add more fixes (Dec 24, 2022)
c179f18 - Add image_id to all annotations (Dec 24, 2022)
16f4aaa - Add image_id to all annotations (Dec 24, 2022)
de90b26 - First draft (Dec 25, 2022)
fac3bb5 - Update features (Dec 25, 2022)
c5173f0 - Update features (Dec 25, 2022)
059b652 - Update features (Dec 25, 2022)
c77e326 - Debug (Dec 25, 2022)
e53249f - Add feature (Dec 25, 2022)
ac5827f - Update features (Jan 15, 2023)
d15b9ae - Improve implementation (Jan 15, 2023)
3044365 - Disable features (Jan 15, 2023)
49d8205 - Add features (Jan 15, 2023)
0546a84 - Improve features (Jan 15, 2023)
b7a7d2b - Improve features (Jan 15, 2023)
1511fb2 - Debug features (Jan 15, 2023)
aeccea0 - Debug features (Jan 15, 2023)
0ccc704 - Debug features (Jan 15, 2023)
d591d98 - Debug features (Jan 15, 2023)
9fd134e - Remove segments_info (Jan 15, 2023)
0d0b47c - Debug (Jan 15, 2023)
e9ca360 - Debug (Jan 15, 2023)
678c8ea - Debug (Jan 15, 2023)
2f0ef02 - Debug (Jan 23, 2023)
f485dc7 - Add features back (Jan 23, 2023)
095ebfc - Add print statement (Jan 23, 2023)
297ebbd - Debug (Jan 23, 2023)
ff776c2 - Debug (Jan 23, 2023)
a76ca5d - Debug (Jan 23, 2023)
cb1464e - Clean up code (Jan 30, 2023)
8c6e29b - Fix style (Jan 30, 2023)
398 changes: 398 additions & 0 deletions metrics/panoptic_quality/panoptic_quality.py
@@ -0,0 +1,398 @@
# Copyright 2022 The HuggingFace Evaluate Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Panoptic Quality (PQ) metric.

Entirely based on https://github.com/cocodataset/panopticapi/blob/master/panopticapi/evaluation.py.
"""

import functools
import io
import json
import multiprocessing
import os
import traceback
from collections import defaultdict

import numpy as np
from PIL import Image

import datasets
import evaluate


_DESCRIPTION = """...
"""

_KWARGS_DESCRIPTION = """
Args:
predictions (`List[ndarray]`):
List of predicted segmentation maps, each of shape (height, width). Each segmentation map can be of a different size.
references (`List[ndarray]`):
List of ground truth segmentation maps, each of shape (height, width). Each segmentation map can be of a different size.
predicted_annotations (`List[ndarray]`):
List of predicted annotations (segments info).
reference_annotations (`List[ndarray]`):
List of reference annotations (segments info).
output_dir (`str`):
Path to the output directory.
categories (`dict`):
Dictionary mapping category IDs to something like {'name': 'wall', 'id': 0, 'isthing': 0, 'color': [120, 120, 120]}.
Example here: https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json.

Returns:
`Dict[str, float | ndarray]` comprising various elements:
- *panoptic_quality* (`float`):
Panoptic quality score.
...

Examples:

>>> import numpy as np

>>> panoptic_quality = evaluate.load("panoptic_quality")

>>> # TODO
"""

_CITATION = """..."""


# This decorator is used to print an error thrown inside a worker process
def get_traceback(f):
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        try:
            return f(*args, **kwargs)
        except Exception as e:
            print("Caught exception in worker thread:")
            traceback.print_exc()
            raise e

    return wrapper


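# COCO panoptic PNGs encode each segment id as an RGB color via id = R + 256 * G + 256**2 * B.
# For example, the color [5, 1, 0] encodes segment id 5 + 256 * 1 = 261, and id2rgb(261)
# maps it back to [5, 1, 0].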
def rgb2id(color):
    if isinstance(color, np.ndarray) and len(color.shape) == 3:
        if color.dtype == np.uint8:
            color = color.astype(np.int32)
        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])


def id2rgb(id_map):
    if isinstance(id_map, np.ndarray):
        id_map_copy = id_map.copy()
        rgb_shape = tuple(list(id_map.shape) + [3])
        rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
        for i in range(3):
            rgb_map[..., i] = id_map_copy % 256
            id_map_copy //= 256
        return rgb_map
    color = []
    for _ in range(3):
        color.append(id_map % 256)
        id_map //= 256
    return color


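# Segment ids are packed as gt_id * OFFSET + pred_id below, so OFFSET must exceed the
# largest possible id (the 24-bit RGB encoding yields at most 256**3 distinct ids).
# VOID marks unlabeled pixels.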
OFFSET = 256 * 256 * 256
VOID = 0


class PQStatCat:
    def __init__(self):
        self.iou = 0.0
        self.tp = 0
        self.fp = 0
        self.fn = 0

    def __iadd__(self, pq_stat_cat):
        self.iou += pq_stat_cat.iou
        self.tp += pq_stat_cat.tp
        self.fp += pq_stat_cat.fp
        self.fn += pq_stat_cat.fn
        return self


class PQStat:
    def __init__(self):
        self.pq_per_cat = defaultdict(PQStatCat)

    def __getitem__(self, i):
        return self.pq_per_cat[i]

    def __iadd__(self, pq_stat):
        for label, pq_stat_cat in pq_stat.pq_per_cat.items():
            self.pq_per_cat[label] += pq_stat_cat
        return self

    def pq_average(self, categories, isthing):
        pq, sq, rq, n = 0, 0, 0, 0
        per_class_results = {}
        for label, label_info in categories.items():
            if isthing is not None:
                cat_isthing = label_info["isthing"] == 1
                if isthing != cat_isthing:
                    continue
            iou = self.pq_per_cat[label].iou
            tp = self.pq_per_cat[label].tp
            fp = self.pq_per_cat[label].fp
            fn = self.pq_per_cat[label].fn
            if tp + fp + fn == 0:
                per_class_results[label] = {"pq": 0.0, "sq": 0.0, "rq": 0.0}
                continue
            n += 1
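            # Per class: PQ = sum of matched IoUs / (TP + 0.5 * FP + 0.5 * FN), which
            # factors as PQ = SQ * RQ, with SQ = sum IoU / TP (segmentation quality)
            # and RQ = TP / (TP + 0.5 * FP + 0.5 * FN) (recognition quality).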
            pq_class = iou / (tp + 0.5 * fp + 0.5 * fn)
            sq_class = iou / tp if tp != 0 else 0
            rq_class = tp / (tp + 0.5 * fp + 0.5 * fn)
            per_class_results[label] = {"pq": pq_class, "sq": sq_class, "rq": rq_class}
            pq += pq_class
            sq += sq_class
            rq += rq_class

        return {"pq": pq / n, "sq": sq / n, "rq": rq / n, "n": n}, per_class_results


@get_traceback
def pq_compute_single_core(proc_id, annotation_set, predictions, references, categories):
    print("Annotation set:", annotation_set)

    pq_stat = PQStat()

    idx = 0
    for pan_pred, pan_gt, (pred_ann, gt_ann) in zip(predictions, references, annotation_set):
        if idx % 100 == 0:
            print("Core: {}, {} from {} images processed".format(proc_id, idx, len(annotation_set)))
        idx += 1

        # we go from RGB space to id space here
        # pan_gt = np.array(Image.open(os.path.join(gt_folder, gt_ann["file_name"])), dtype=np.uint32)
        pan_gt = rgb2id(np.array(pan_gt))
        # pan_pred = np.array(Image.open(os.path.join(pred_folder, pred_ann["file_name"])), dtype=np.uint32)
        pan_pred = rgb2id(np.array(pan_pred))

        print("Ground truth annotation: ", gt_ann)
        print("Predicted annotation: ", pred_ann)

        # gt_segms = {el["id"]: el for el in gt_ann}
        # pred_segms = {el["id"]: el for el in pred_ann}

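        # The segments info arrive as one dict of columns per image (e.g. {"id": [7, 9], "category_id": [1, 3], ...}),
        # so transpose them into one dict per segment id: {7: {"id": 7, "category_id": 1, ...}, 9: {...}}.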
        gt_segms = {id: {k: v[idx] for k, v in gt_ann.items()} for idx, id in enumerate(gt_ann["id"])}
        pred_segms = {id: {k: v[idx] for k, v in pred_ann.items()} for idx, id in enumerate(pred_ann["id"])}

        print("Ground truth segments:", gt_segms)
        print("Predicted segments:", pred_segms)

        # predicted segments area calculation + prediction sanity checks
        # pred_labels_set = set(el["id"] for el in pred_ann)
        pred_labels_set = set(pred_ann["id"])
        labels, labels_cnt = np.unique(pan_pred, return_counts=True)

        print("Predicted labels set:", pred_labels_set)
        print("Labels:", labels)
        print("Labels count:", labels_cnt)

        print("Predicted segments:", pred_segms.keys())

        for label, label_cnt in zip(labels, labels_cnt):
            print("Label:", label)
            if label not in pred_segms:
                print(f"Label {label} not in predicted segments {pred_segms.keys()}")
                if label == VOID:
                    continue
                # raise KeyError(
                #     "In the image with ID {} segment with ID {} is presented in PNG and not presented in JSON.".format(
                #         gt_ann["image_id"], label
                #     )
                # )
                raise KeyError("The segment with ID {} is presented in PNG and not presented in JSON.".format(label))
            pred_segms[label]["area"] = label_cnt
            pred_labels_set.remove(label)
            if pred_segms[label]["category_id"] not in categories:
                # raise KeyError(
                #     "In the image with ID {} segment with ID {} has unknown category_id {}.".format(
                #         gt_ann["image_id"], label, pred_segms[label]["category_id"]
                #     )
                # )
                raise KeyError(
                    "The segment with ID {} has unknown category_id {}.".format(
                        label, pred_segms[label]["category_id"]
                    )
                )
        if len(pred_labels_set) != 0:
            # raise KeyError(
            #     "In the image with ID {} the following segment IDs {} are presented in JSON and not presented in PNG.".format(
            #         gt_ann["image_id"], list(pred_labels_set)
            #     )
            # )
            raise KeyError(
                "The following segment IDs {} are presented in JSON and not presented in PNG.".format(
                    list(pred_labels_set)
                )
            )

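        # Pack every (gt_id, pred_id) pixel pair into a single integer so that one
        # np.unique call yields the intersection area of each ground-truth/prediction pair.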
        # confusion matrix calculation
        pan_gt_pred = pan_gt.astype(np.uint64) * OFFSET + pan_pred.astype(np.uint64)
        gt_pred_map = {}
        labels, labels_cnt = np.unique(pan_gt_pred, return_counts=True)
        for label, intersection in zip(labels, labels_cnt):
            gt_id = label // OFFSET
            pred_id = label % OFFSET
            gt_pred_map[(gt_id, pred_id)] = intersection

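        # Matching at IoU > 0.5 is unique: the intersection then covers more than half of
        # each segment, and segments are disjoint, so a ground-truth segment can match at
        # most one prediction and vice versa.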
        # count all matched pairs
        gt_matched = set()
        pred_matched = set()
        for label_tuple, intersection in gt_pred_map.items():
            gt_label, pred_label = label_tuple
            if gt_label not in gt_segms:
                continue
            if pred_label not in pred_segms:
                continue
            if gt_segms[gt_label]["iscrowd"] == 1:
                continue
            if gt_segms[gt_label]["category_id"] != pred_segms[pred_label]["category_id"]:
                continue

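            # The union discounts pixels that are VOID in the ground truth, so unlabeled
            # regions do not penalize an otherwise correct prediction.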
            union = (
                pred_segms[pred_label]["area"]
                + gt_segms[gt_label]["area"]
                - intersection
                - gt_pred_map.get((VOID, pred_label), 0)
            )
            iou = intersection / union
            if iou > 0.5:
                pq_stat[gt_segms[gt_label]["category_id"]].tp += 1
                pq_stat[gt_segms[gt_label]["category_id"]].iou += iou
                gt_matched.add(gt_label)
                pred_matched.add(pred_label)

        # count false negatives
        crowd_labels_dict = {}
        for gt_label, gt_info in gt_segms.items():
            if gt_label in gt_matched:
                continue
            # crowd segments are ignored
            if gt_info["iscrowd"] == 1:
                crowd_labels_dict[gt_info["category_id"]] = gt_label
                continue
            pq_stat[gt_info["category_id"]].fn += 1

        # count false positives
        for pred_label, pred_info in pred_segms.items():
            if pred_label in pred_matched:
                continue
            # intersection of the segment with VOID
            intersection = gt_pred_map.get((VOID, pred_label), 0)
            # plus intersection with corresponding CROWD region if it exists
            if pred_info["category_id"] in crowd_labels_dict:
                intersection += gt_pred_map.get((crowd_labels_dict[pred_info["category_id"]], pred_label), 0)
            # predicted segment is ignored if more than half of the segment corresponds to VOID and CROWD regions
            if intersection / pred_info["area"] > 0.5:
                continue
            pq_stat[pred_info["category_id"]].fp += 1
    print("Core: {}, all {} images processed".format(proc_id, len(annotation_set)))
    return pq_stat


def pq_compute_multi_core(matched_annotations_list, predictions, references, categories):
    cpu_num = multiprocessing.cpu_count()
    # TODO support multiprocessing
    # fix cpu number for now (DEBUGGING)
    cpu_num = 1
    annotations_split = np.array_split(matched_annotations_list, cpu_num)
    print("Number of cores: {}, images per core: {}".format(cpu_num, len(annotations_split[0])))
    workers = multiprocessing.Pool(processes=cpu_num)
    processes = []
    for proc_id, annotation_set in enumerate(annotations_split):
        p = workers.apply_async(pq_compute_single_core, (proc_id, annotation_set, predictions, references, categories))
        processes.append(p)
    pq_stat = PQStat()
    for p in processes:
        pq_stat += p.get()
    return pq_stat


def pq_compute(predictions, references, predicted_annotations, reference_annotations, categories):
    matched_annotations_list = []
    for pred_ann, gt_ann in zip(predicted_annotations, reference_annotations):
        matched_annotations_list.append((pred_ann, gt_ann))

    pq_stat = pq_compute_multi_core(matched_annotations_list, predictions, references, categories)

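    # "Things" are countable instance categories (isthing == 1), "Stuff" covers amorphous
    # region categories (isthing == 0); PQ is conventionally reported for both plus "All".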
    metrics = [("All", None), ("Things", True), ("Stuff", False)]
    results = {}
    for name, isthing in metrics:
        results[name], per_class_results = pq_stat.pq_average(categories, isthing=isthing)
        if name == "All":
            results["per_class"] = per_class_results
    print("{:10s}| {:>5s} {:>5s} {:>5s} {:>5s}".format("", "PQ", "SQ", "RQ", "N"))
    print("-" * (10 + 7 * 4))

    for name, _isthing in metrics:
        print(
            "{:10s}| {:5.1f} {:5.1f} {:5.1f} {:5d}".format(
                name,
                100 * results[name]["pq"],
                100 * results[name]["sq"],
                100 * results[name]["rq"],
                results[name]["n"],
            )
        )

    return results


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class PanopticQuality(evaluate.Metric):
    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "predictions": datasets.Image(),
                    "references": datasets.Image(),
                    "predicted_annotations": datasets.Sequence(
                        {
                            "id": datasets.Value("int32"),
                            "category_id": datasets.Value("int32"),
                            "was_fused": datasets.Value("bool"),
                            "score": datasets.Value("float32"),
                        }
                    ),
                    "reference_annotations": datasets.Sequence(
                        {
                            "id": datasets.Value("int32"),
                            "category_id": datasets.Value("int32"),
                            "iscrowd": datasets.Value("int32"),
                            "area": datasets.Value("int32"),
                            "bbox": datasets.Sequence(datasets.Value("int32")),
                        }
                    ),
                }
            ),
            reference_urls=["https://github.com/cocodataset/panopticapi/blob/master/panopticapi/evaluation.py"],
        )

    def _compute(
        self,
        predictions,
        references,
        predicted_annotations,
        reference_annotations,
        categories=None,
    ):
        result = pq_compute(predictions, references, predicted_annotations, reference_annotations, categories)

        return result
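Since the docstring example is still a TODO, here is a minimal sketch of how pq_compute could be exercised end to end. It assumes it runs in the same module as the script above (or imports pq_compute from it); the 2x2 maps, segment ids, and category table are hypothetical toy values chosen to satisfy the feature schema, not part of the PR.

import numpy as np

# Hypothetical toy inputs: a 2x2 image split into a "stuff" segment (id 1, left column)
# and a "thing" segment (id 2, right column); the prediction equals the reference.
seg_map = np.zeros((2, 2, 3), dtype=np.uint8)
seg_map[:, 0, 0] = 1  # segment id 1 encoded as RGB [1, 0, 0]
seg_map[:, 1, 0] = 2  # segment id 2 encoded as RGB [2, 0, 0]

# segments info in the columnar layout consumed by pq_compute_single_core
predicted_annotations = [{"id": [1, 2], "category_id": [0, 1], "was_fused": [False, False], "score": [1.0, 1.0]}]
reference_annotations = [{"id": [1, 2], "category_id": [0, 1], "iscrowd": [0, 0], "area": [2, 2], "bbox": [[0, 0, 1, 2], [1, 0, 1, 2]]}]

categories = {
    0: {"name": "wall", "id": 0, "isthing": 0, "color": [120, 120, 120]},
    1: {"name": "person", "id": 1, "isthing": 1, "color": [220, 20, 60]},
}

if __name__ == "__main__":  # guard needed because pq_compute spawns a worker pool
    results = pq_compute(
        predictions=[seg_map],
        references=[seg_map.copy()],
        predicted_annotations=predicted_annotations,
        reference_annotations=reference_annotations,
        categories=categories,
    )
    # A perfect match gives PQ = SQ = RQ = 1.0 for "All", "Things" and "Stuff".
    print(results["All"])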