Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
johnnynunez authored Apr 9, 2024
2 parents 173b59f + b7c7f4b commit 86267e6
Show file tree
Hide file tree
Showing 136 changed files with 791 additions and 185 deletions.
25 changes: 19 additions & 6 deletions demo/demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
import argparse
import glob
import multiprocessing as mp
import numpy as np
import os
import tempfile
import time
import warnings

import cv2
import numpy as np
import tqdm

from detectron2.config import get_cfg
Expand All @@ -31,7 +32,9 @@ def setup_cfg(args):
# Set score_threshold for builtin models
cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = (
args.confidence_threshold
)
cfg.freeze()
return cfg

Expand All @@ -44,7 +47,9 @@ def get_parser():
metavar="FILE",
help="path to config file",
)
parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
parser.add_argument(
"--webcam", action="store_true", help="Take inputs from webcam."
)
parser.add_argument("--video-input", help="Path to video file.")
parser.add_argument(
"--input",
Expand Down Expand Up @@ -90,7 +95,7 @@ def test_opencv_video_format(codec, file_ext):
return False


if __name__ == "__main__":
def main() -> None:
mp.set_start_method("spawn", force=True)
args = get_parser().parse_args()
setup_logger(name="fvcore")
Expand Down Expand Up @@ -125,7 +130,9 @@ def test_opencv_video_format(codec, file_ext):
assert os.path.isdir(args.output), args.output
out_filename = os.path.join(args.output, os.path.basename(path))
else:
assert len(args.input) == 1, "Please specify a directory with args.output"
assert (
len(args.input) == 1
), "Please specify a directory with args.output"
out_filename = args.output
visualized_output.save(out_filename)
else:
Expand All @@ -152,7 +159,9 @@ def test_opencv_video_format(codec, file_ext):
num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
basename = os.path.basename(args.video_input)
codec, file_ext = (
("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4")
("x264", ".mkv")
if test_opencv_video_format("x264", ".mkv")
else ("mp4v", ".mp4")
)
if codec == ".mp4v":
warnings.warn("x264 codec not available, switching to mp4v")
Expand Down Expand Up @@ -186,3 +195,7 @@ def test_opencv_video_format(codec, file_ext):
output_file.release()
else:
cv2.destroyAllWindows()


if __name__ == "__main__":
main() # pragma: no cover
2 changes: 2 additions & 0 deletions detectron2/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@
_C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler"
# Repeat threshold for RepeatFactorTrainingSampler
_C.DATALOADER.REPEAT_THRESHOLD = 0.0
# if True, take square root when computing repeating factor
_C.DATALOADER.REPEAT_SQRT = True
# Tf True, when working on datasets that have instance annotations, the
# training dataloader will filter out images without associated annotations
_C.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True
Expand Down
12 changes: 10 additions & 2 deletions detectron2/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ def build_batch_data_loader(
collate_fn=None,
drop_last: bool = True,
single_gpu_batch_size=None,
seed=None,
**kwargs,
):
"""
Expand Down Expand Up @@ -347,13 +348,19 @@ def build_batch_data_loader(
else:
dataset = ToIterableDataset(dataset, sampler, shard_chunk_size=batch_size)

generator = None
if seed is not None:
generator = torch.Generator()
generator.manual_seed(seed)

if aspect_ratio_grouping:
assert drop_last, "Aspect ratio grouping will drop incomplete batches."
data_loader = torchdata.DataLoader(
dataset,
num_workers=num_workers,
collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements
worker_init_fn=worker_init_reset_seed,
generator=generator,
**kwargs
) # yield individual mapped dict
data_loader = AspectRatioGroupedDataset(data_loader, batch_size)
Expand All @@ -368,6 +375,7 @@ def build_batch_data_loader(
num_workers=num_workers,
collate_fn=trivial_batch_collator if collate_fn is None else collate_fn,
worker_init_fn=worker_init_reset_seed,
generator=generator,
**kwargs
)

Expand Down Expand Up @@ -422,7 +430,7 @@ def _build_weighted_sampler(cfg, enable_category_balance=False):
"""
category_repeat_factors = [
RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD
dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD, sqrt=cfg.DATALOADER.REPEAT_SQRT
)
for dataset_dict in dataset_name_to_dicts.values()
]
Expand Down Expand Up @@ -474,7 +482,7 @@ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None):
sampler = TrainingSampler(len(dataset))
elif sampler_name == "RepeatFactorTrainingSampler":
repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency(
dataset, cfg.DATALOADER.REPEAT_THRESHOLD
dataset, cfg.DATALOADER.REPEAT_THRESHOLD, sqrt=cfg.DATALOADER.REPEAT_SQRT
)
sampler = RepeatFactorTrainingSampler(repeat_factors)
elif sampler_name == "RandomSubsetTrainingSampler":
Expand Down
40 changes: 28 additions & 12 deletions detectron2/data/datasets/cityscapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,17 @@
import json
import logging
import multiprocessing as mp
import numpy as np
import os
from itertools import chain

import numpy as np
import pycocotools.mask as mask_util
from PIL import Image

from detectron2.structures import BoxMode
from detectron2.utils.comm import get_world_size
from detectron2.utils.file_io import PathManager
from detectron2.utils.logger import setup_logger
from PIL import Image

try:
import cv2 # noqa
Expand All @@ -39,7 +40,9 @@ def _get_cityscapes_files(image_dir, gt_dir):
assert basename.endswith(suffix), basename
basename = basename[: -len(suffix)]

instance_file = os.path.join(city_gt_dir, basename + "gtFine_instanceIds.png")
instance_file = os.path.join(
city_gt_dir, basename + "gtFine_instanceIds.png"
)
label_file = os.path.join(city_gt_dir, basename + "gtFine_labelIds.png")
json_file = os.path.join(city_gt_dir, basename + "gtFine_polygons.json")

Expand Down Expand Up @@ -76,7 +79,9 @@ def load_cityscapes_instances(image_dir, gt_dir, from_json=True, to_polygons=Tru
pool = mp.Pool(processes=max(mp.cpu_count() // get_world_size() // 2, 4))

ret = pool.map(
functools.partial(_cityscapes_files_to_dict, from_json=from_json, to_polygons=to_polygons),
functools.partial(
_cityscapes_files_to_dict, from_json=from_json, to_polygons=to_polygons
),
files,
)
logger.info("Loaded {} images from {}".format(len(ret), image_dir))
Expand Down Expand Up @@ -105,7 +110,9 @@ def load_cityscapes_semantic(image_dir, gt_dir):
ret = []
# gt_dir is small and contain many small files. make sense to fetch to local first
gt_dir = PathManager.get_local_path(gt_dir)
for image_file, _, label_file, json_file in _get_cityscapes_files(image_dir, gt_dir):
for image_file, _, label_file, json_file in _get_cityscapes_files(
image_dir, gt_dir
):
label_file = label_file.replace("labelIds", "labelTrainIds")

with PathManager.open(json_file, "r") as f:
Expand Down Expand Up @@ -209,7 +216,9 @@ def _cityscapes_files_to_dict(files, from_json, to_polygons):
elif isinstance(poly_wo_overlaps, MultiPolygon):
poly_list = poly_wo_overlaps.geoms
else:
raise NotImplementedError("Unknown geometric structure {}".format(poly_wo_overlaps))
raise NotImplementedError(
"Unknown geometric structure {}".format(poly_wo_overlaps)
)

poly_coord = []
for poly_el in poly_list:
Expand Down Expand Up @@ -263,9 +272,9 @@ def _cityscapes_files_to_dict(files, from_json, to_polygons):
if to_polygons:
# This conversion comes from D4809743 and D5171122,
# when Mask-RCNN was first developed.
contours = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[
-2
]
contours = cv2.findContours(
mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE
)[-2]
polygons = [c.reshape(-1).tolist() for c in contours if len(c) >= 3]
# opencv's can produce invalid polygons
if len(polygons) == 0:
Expand All @@ -278,7 +287,8 @@ def _cityscapes_files_to_dict(files, from_json, to_polygons):
return ret


if __name__ == "__main__":
def main() -> None:
global logger, labels
"""
Test the cityscapes dataset loader.
Expand All @@ -293,9 +303,9 @@ def _cityscapes_files_to_dict(files, from_json, to_polygons):
parser.add_argument("gt_dir")
parser.add_argument("--type", choices=["instance", "semantic"], default="instance")
args = parser.parse_args()
from cityscapesscripts.helpers.labels import labels
from detectron2.data.catalog import Metadata
from detectron2.utils.visualizer import Visualizer
from cityscapesscripts.helpers.labels import labels

logger = setup_logger(name=__name__)

Expand All @@ -308,7 +318,9 @@ def _cityscapes_files_to_dict(files, from_json, to_polygons):
)
logger.info("Done loading {} samples.".format(len(dicts)))

thing_classes = [k.name for k in labels if k.hasInstances and not k.ignoreInEval]
thing_classes = [
k.name for k in labels if k.hasInstances and not k.ignoreInEval
]
meta = Metadata().set(thing_classes=thing_classes)

else:
Expand All @@ -327,3 +339,7 @@ def _cityscapes_files_to_dict(files, from_json, to_polygons):
# cv2.waitKey()
fpath = os.path.join(dirname, os.path.basename(d["file_name"]))
vis.save(fpath)


if __name__ == "__main__":
main() # pragma: no cover
Loading

0 comments on commit 86267e6

Please sign in to comment.