From cd71c12a946b58bf9ac9d76816476468094304b4 Mon Sep 17 00:00:00 2001 From: bobo0810 <1055271769@qq.com> Date: Mon, 7 Aug 2023 21:14:11 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E3=80=90Feature=E3=80=91Class=20balanced?= =?UTF-8?q?=20sampling=20at=20the=20Batch=20level?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- mmpretrain/datasets/samplers/__init__.py | 3 +- mmpretrain/datasets/samplers/batch_balance.py | 107 ++++++++++++++++++ 2 files changed, 109 insertions(+), 1 deletion(-) create mode 100644 mmpretrain/datasets/samplers/batch_balance.py diff --git a/mmpretrain/datasets/samplers/__init__.py b/mmpretrain/datasets/samplers/__init__.py index 2bccf9c3465..f8d93f1dbce 100644 --- a/mmpretrain/datasets/samplers/__init__.py +++ b/mmpretrain/datasets/samplers/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .repeat_aug import RepeatAugSampler from .sequential import SequentialSampler +from .batch_balance import BatchBalanceSampler -__all__ = ['RepeatAugSampler', 'SequentialSampler'] +__all__ = ['RepeatAugSampler', 'SequentialSampler', 'BatchBalanceSampler'] diff --git a/mmpretrain/datasets/samplers/batch_balance.py b/mmpretrain/datasets/samplers/batch_balance.py new file mode 100644 index 00000000000..386fc2e31aa --- /dev/null +++ b/mmpretrain/datasets/samplers/batch_balance.py @@ -0,0 +1,107 @@ +from typing import Iterator +from mmengine.dataset import DefaultSampler +from mmpretrain.registry import DATA_SAMPLERS +import numpy as np +import math +import collections +import torch + + +@DATA_SAMPLERS.register_module() +class BatchBalanceSampler(DefaultSampler): + """ + refer: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/v2.3.0/src/pytorch_metric_learning/samplers/num_per_class_sampler.py + + At every iteration, this will return m samples per class. For example, + if dataloader's batchsize is 100, and m = 5, then 20 classes with 5 samples + each will be returned + + + Args: + num_per_class: number of samples per class in a batch + + Examples: + train_dataloader = dict( + xxxx, + sampler=dict(type="BatchBalanceSampler", num_per_class=4), + ) + """ + + def __init__(self, num_per_class, **kwargs) -> None: + super().__init__(**kwargs) + self.num_per_class = int(num_per_class) + self.labels_to_indices = self.get_labels_to_indices( + self.dataset.get_gt_labels() + ) + self.labels = list(self.labels_to_indices.keys()) # labels index list + self.length_of_single_pass = self.num_per_class * len(self.labels) + + self.total_size = len(self.dataset) + # It must be an integer multiple of length_of_single_pass + if self.length_of_single_pass < self.total_size: + self.total_size -= (self.total_size) % (self.length_of_single_pass) + # The number of samples in this rank + self.num_samples = math.ceil((self.total_size - self.rank) / self.world_size) + + def __len__(self) -> int: + """The number of samples in this rank.""" + return self.num_samples + + def __iter__(self) -> Iterator[int]: + indices = [0] * self.total_size + i = 0 + num_iters = self.calculate_num_iters() + for _ in range(num_iters): + np.random.shuffle(self.labels) + curr_label_set = self.labels + for label in curr_label_set: + t = self.labels_to_indices[label] # List of all sample indexes corresponding to the current label + indices[i : i + self.num_per_class] = self.safe_random_choice( + t, size=self.num_per_class + ) + i += self.num_per_class + # subsample + indices = indices[self.rank : self.total_size : self.world_size] + + return iter(indices) + + def calculate_num_iters(self): + divisor = self.length_of_single_pass + return self.total_size // divisor if divisor < self.total_size else 1 + + def safe_random_choice(self, input_data, size): + """ + Randomly samples without replacement from a sequence. It is "safe" because + if len(input_data) < size, it will randomly sample WITH replacement + Args: + input_data is a sequence, like a torch tensor, numpy array, + python list, tuple etc + size is the number of elements to randomly sample from input_data + Returns: + An array of size "size", randomly sampled from input_data + """ + replace = len(input_data) < size + return np.random.choice(input_data, size=size, replace=replace) + + def get_labels_to_indices(self, labels): + """ + Creates labels_to_indices, which is a dictionary mapping each label + to a numpy array of indices that will be used to index into self.dataset + + {labels_index:Index of samples belonging to the category} + + eg: { + "0":[1,3,6,8], + "1":[2,4,5,7], + "2":[0,9,10] + } + + """ + if torch.is_tensor(labels): + labels = labels.cpu().numpy() + labels_to_indices = collections.defaultdict(list) + for i, label in enumerate(labels): + labels_to_indices[label].append(i) + for k, v in labels_to_indices.items(): + labels_to_indices[k] = np.array(v, dtype=int) + return labels_to_indices From ec3a59d23ea659590c6878b4ad2649772cbbcaaa Mon Sep 17 00:00:00 2001 From: bobo0810 <1055271769@qq.com> Date: Wed, 9 Aug 2023 11:12:29 +0800 Subject: [PATCH 2/2] code formatting --- mmpretrain/datasets/samplers/__init__.py | 2 +- mmpretrain/datasets/samplers/batch_balance.py | 45 ++++++++++--------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/mmpretrain/datasets/samplers/__init__.py b/mmpretrain/datasets/samplers/__init__.py index f8d93f1dbce..56e8af45695 100644 --- a/mmpretrain/datasets/samplers/__init__.py +++ b/mmpretrain/datasets/samplers/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from .batch_balance import BatchBalanceSampler from .repeat_aug import RepeatAugSampler from .sequential import SequentialSampler -from .batch_balance import BatchBalanceSampler __all__ = ['RepeatAugSampler', 'SequentialSampler', 'BatchBalanceSampler'] diff --git a/mmpretrain/datasets/samplers/batch_balance.py b/mmpretrain/datasets/samplers/batch_balance.py index 386fc2e31aa..959f2081ea6 100644 --- a/mmpretrain/datasets/samplers/batch_balance.py +++ b/mmpretrain/datasets/samplers/batch_balance.py @@ -1,16 +1,21 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import collections +import math from typing import Iterator -from mmengine.dataset import DefaultSampler -from mmpretrain.registry import DATA_SAMPLERS + import numpy as np -import math -import collections import torch +from mmengine.dataset import DefaultSampler + +from mmpretrain.registry import DATA_SAMPLERS @DATA_SAMPLERS.register_module() class BatchBalanceSampler(DefaultSampler): """ - refer: https://github.com/KevinMusgrave/pytorch-metric-learning/blob/v2.3.0/src/pytorch_metric_learning/samplers/num_per_class_sampler.py + refer: https://github.com/KevinMusgrave/pytorch-metric-learning/ + blob/v2.3.0/src/pytorch_metric_learning/samplers/num_per_class_sampler.py + At every iteration, this will return m samples per class. For example, if dataloader's batchsize is 100, and m = 5, then 20 classes with 5 samples @@ -31,9 +36,8 @@ def __init__(self, num_per_class, **kwargs) -> None: super().__init__(**kwargs) self.num_per_class = int(num_per_class) self.labels_to_indices = self.get_labels_to_indices( - self.dataset.get_gt_labels() - ) - self.labels = list(self.labels_to_indices.keys()) # labels index list + self.dataset.get_gt_labels()) + self.labels = list(self.labels_to_indices.keys()) # labels index list self.length_of_single_pass = self.num_per_class * len(self.labels) self.total_size = len(self.dataset) @@ -41,7 +45,8 @@ def __init__(self, num_per_class, **kwargs) -> None: if self.length_of_single_pass < self.total_size: self.total_size -= (self.total_size) % (self.length_of_single_pass) # The number of samples in this rank - self.num_samples = math.ceil((self.total_size - self.rank) / self.world_size) + self.num_samples = math.ceil( + (self.total_size - self.rank) / self.world_size) def __len__(self) -> int: """The number of samples in this rank.""" @@ -55,13 +60,13 @@ def __iter__(self) -> Iterator[int]: np.random.shuffle(self.labels) curr_label_set = self.labels for label in curr_label_set: - t = self.labels_to_indices[label] # List of all sample indexes corresponding to the current label - indices[i : i + self.num_per_class] = self.safe_random_choice( - t, size=self.num_per_class - ) + # List of all sample indexes corresponding to the current label + t = self.labels_to_indices[label] + indices[i:i + self.num_per_class] = self.safe_random_choice( + t, size=self.num_per_class) i += self.num_per_class # subsample - indices = indices[self.rank : self.total_size : self.world_size] + indices = indices[self.rank:self.total_size:self.world_size] return iter(indices) @@ -70,8 +75,9 @@ def calculate_num_iters(self): return self.total_size // divisor if divisor < self.total_size else 1 def safe_random_choice(self, input_data, size): - """ - Randomly samples without replacement from a sequence. It is "safe" because + """Randomly samples without replacement from a sequence. + + It is "safe" because if len(input_data) < size, it will randomly sample WITH replacement Args: input_data is a sequence, like a torch tensor, numpy array, @@ -84,9 +90,9 @@ def safe_random_choice(self, input_data, size): return np.random.choice(input_data, size=size, replace=replace) def get_labels_to_indices(self, labels): - """ - Creates labels_to_indices, which is a dictionary mapping each label - to a numpy array of indices that will be used to index into self.dataset + """Creates labels_to_indices, which is a dictionary mapping each label + to a numpy array of indices that will be used to index into + self.dataset. {labels_index:Index of samples belonging to the category} @@ -95,7 +101,6 @@ def get_labels_to_indices(self, labels): "1":[2,4,5,7], "2":[0,9,10] } - """ if torch.is_tensor(labels): labels = labels.cpu().numpy()