From afe9eb920646102f7e6bf0cd2115841cea2aca13 Mon Sep 17 00:00:00 2001 From: Zhicheng Yan Date: Wed, 27 Mar 2024 23:51:03 -0700 Subject: [PATCH] add DATALOADER.REPEAT_SQRT Summary: Pull Request resolved: https://github.com/facebookresearch/detectron2/pull/5245 X-link: https://github.com/fairinternal/detectron2/pull/602 For sampler **RepeatFactorTrainingSampler**, current per-category weight is computed as **1/sqrt(frequency)**. This works fine on LVIS but is not sufficient in highly imbalanced data we have for person segmentation. Thus we add an argument **DATALOADER.REPEAT_SQRT**. If false, we compute per-category weight as **1/frequency**. This change is entirely back-compatible. Reviewed By: wat3rBro Differential Revision: D55355021 Privacy Context Container: L1165023 fbshipit-source-id: 6bca2eecc3b9a7b4693a288c5779627254cd5ec5 --- detectron2/config/defaults.py | 2 ++ detectron2/data/build.py | 4 ++-- detectron2/data/samplers/distributed_sampler.py | 12 ++++++++++-- 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/detectron2/config/defaults.py b/detectron2/config/defaults.py index bd2a5f6b2d..5d97ec92d2 100644 --- a/detectron2/config/defaults.py +++ b/detectron2/config/defaults.py @@ -121,6 +121,8 @@ _C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler" # Repeat threshold for RepeatFactorTrainingSampler _C.DATALOADER.REPEAT_THRESHOLD = 0.0 +# if True, take square root when computing repeating factor +_C.DATALOADER.REPEAT_SQRT = True # Tf True, when working on datasets that have instance annotations, the # training dataloader will filter out images without associated annotations _C.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True diff --git a/detectron2/data/build.py b/detectron2/data/build.py index 42867687e3..a3bd94b4ed 100644 --- a/detectron2/data/build.py +++ b/detectron2/data/build.py @@ -430,7 +430,7 @@ def _build_weighted_sampler(cfg, enable_category_balance=False): """ category_repeat_factors = [ RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( - dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD + dataset_dict, cfg.DATALOADER.REPEAT_THRESHOLD, sqrt=cfg.DATALOADER.REPEAT_SQRT ) for dataset_dict in dataset_name_to_dicts.values() ] @@ -482,7 +482,7 @@ def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None): sampler = TrainingSampler(len(dataset)) elif sampler_name == "RepeatFactorTrainingSampler": repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( - dataset, cfg.DATALOADER.REPEAT_THRESHOLD + dataset, cfg.DATALOADER.REPEAT_THRESHOLD, sqrt=cfg.DATALOADER.REPEAT_SQRT ) sampler = RepeatFactorTrainingSampler(repeat_factors) elif sampler_name == "RandomSubsetTrainingSampler": diff --git a/detectron2/data/samplers/distributed_sampler.py b/detectron2/data/samplers/distributed_sampler.py index b4a9096c95..c4dc22a7a2 100644 --- a/detectron2/data/samplers/distributed_sampler.py +++ b/detectron2/data/samplers/distributed_sampler.py @@ -155,7 +155,7 @@ def __init__(self, repeat_factors, *, shuffle=True, seed=None): self._frac_part = repeat_factors - self._int_part @staticmethod - def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh): + def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh, sqrt=True): """ Compute (fractional) per-image repeat factors based on category frequency. The repeat factor for an image is a function of the frequency of the rarest @@ -169,6 +169,7 @@ def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh): repeat_thresh (float): frequency threshold below which data is repeated. If the frequency is half of `repeat_thresh`, the image will be repeated twice. + sqrt (bool): if True, apply :func:`math.sqrt` to the repeat factor. Returns: torch.Tensor: @@ -187,7 +188,14 @@ def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh): # 2. For each category c, compute the category-level repeat factor: # r(c) = max(1, sqrt(t / f(c))) category_rep = { - cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq)) + cat_id: max( + 1.0, + ( + math.sqrt(repeat_thresh / cat_freq) + if sqrt + else (repeat_thresh / cat_freq) + ), + ) for cat_id, cat_freq in category_freq.items() } for cat_id in sorted(category_rep.keys()):