
Commit b5f5f3b

Merge branch 'main' into main

johnnynunez authored Dec 21, 2023
2 parents 66f7f7a + e9f7e2b
Showing 2 changed files with 20 additions and 9 deletions.
27 changes: 20 additions & 7 deletions detectron2/data/build.py
@@ -300,6 +300,7 @@ def build_batch_data_loader(
     num_workers=0,
     collate_fn=None,
     drop_last: bool = True,
+    single_gpu_batch_size=None,
     **kwargs,
 ):
     """
@@ -313,19 +314,31 @@ def build_batch_data_loader(
             Must be provided iff. ``dataset`` is a map-style dataset.
         total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see
             :func:`build_detection_train_loader`.
+        single_gpu_batch_size: You can specify either `single_gpu_batch_size` or `total_batch_size`.
+            `single_gpu_batch_size` specifies the batch size that will be used for each gpu/process.
+            `total_batch_size` allows you to specify the total aggregate batch size across gpus.
+            It is an error to supply a value for both.
         drop_last (bool): if ``True``, the dataloader will drop incomplete batches.
 
     Returns:
         iterable[list]. Length of each list is the batch size of the current
         GPU. Each element in the list comes from the dataset.
     """
-    world_size = get_world_size()
-    assert (
-        total_batch_size > 0 and total_batch_size % world_size == 0
-    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
-        total_batch_size, world_size
-    )
-    batch_size = total_batch_size // world_size
+    if single_gpu_batch_size:
+        if total_batch_size:
+            raise ValueError(
+                """total_batch_size and single_gpu_batch_size are mutually incompatible.
+                Please specify only one. """
+            )
+        batch_size = single_gpu_batch_size
+    else:
+        world_size = get_world_size()
+        assert (
+            total_batch_size > 0 and total_batch_size % world_size == 0
+        ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
+            total_batch_size, world_size
+        )
+        batch_size = total_batch_size // world_size
     logger = logging.getLogger(__name__)
     logger.info("Making batched data loader with batch_size=%d", batch_size)

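The docstring added above spells out the rule: supply exactly one of `single_gpu_batch_size` or `total_batch_size`, and the latter must divide evenly across processes. The sketch below isolates that resolution logic for illustration only; `resolve_batch_size` is a hypothetical helper name, not part of detectron2's API.

# A minimal sketch of the batch-size resolution shown in the diff above.
# `resolve_batch_size` is a hypothetical name used only for illustration.
def resolve_batch_size(total_batch_size=None, single_gpu_batch_size=None, world_size=1):
    if single_gpu_batch_size:
        if total_batch_size:
            raise ValueError(
                "total_batch_size and single_gpu_batch_size are mutually incompatible. "
                "Please specify only one."
            )
        return single_gpu_batch_size  # used directly on every gpu/process
    assert total_batch_size > 0 and total_batch_size % world_size == 0, (
        "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
            total_batch_size, world_size
        )
    )
    return total_batch_size // world_size  # aggregate size split across gpus

print(resolve_batch_size(total_batch_size=16, world_size=4))      # -> 4 per gpu
print(resolve_batch_size(single_gpu_batch_size=4, world_size=4))  # -> 4 per gpu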
2 changes: 0 additions & 2 deletions projects/DensePose/densepose/vis/densepose_data_points.py
@@ -28,7 +28,6 @@ def visualize(
     ) -> Image:
         if bbox_densepose_datas is None:
             return image_bgr
-        # pyre-fixme[23]: Unable to unpack single value, 2 were expected.
         for bbox_xywh, densepose_data in zip(*bbox_densepose_datas):
             matrix = densepose_data.segm.numpy()
             mask = np.zeros(matrix.shape, dtype=np.uint8)
@@ -50,7 +49,6 @@
     ) -> Image:
         if bbox_densepose_datas is None:
             return image_bgr
-        # pyre-fixme[23]: Unable to unpack single value, 2 were expected.
         for bbox_xywh, densepose_data in zip(*bbox_densepose_datas):
             x0, y0, w, h = bbox_xywh.numpy()
             x = densepose_data.x.numpy() * w / 255.0 + x0
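For context on the DensePose change: the removed pyre-fixme notes flagged the `zip(*bbox_densepose_datas)` pattern, where `bbox_densepose_datas` is a pair of parallel sequences (boxes and per-box DensePose data) and unpacking it into `zip` yields one `(bbox, data)` pair per instance. A standalone sketch of that pattern, with made-up stand-in values:

# Illustrative stand-ins for the (boxes, densepose_datas) pair; the values
# here are invented and are not real DensePose outputs.
boxes = [(10, 20, 50, 80), (60, 40, 30, 30)]        # xywh boxes, one per instance
densepose_datas = ["data_for_box_0", "data_for_box_1"]
bbox_densepose_datas = (boxes, densepose_datas)

# zip(*pair) pairs the i-th box with the i-th data item, as the visualizer loop does.
for bbox_xywh, densepose_data in zip(*bbox_densepose_datas):
    print(bbox_xywh, densepose_data)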
