Skip to content

Commit

Permalink
update update_mem_cache_img_max_size to input_size if it's none
Browse files Browse the repository at this point in the history
  • Loading branch information
eunwoosh committed Aug 14, 2024
1 parent 617b5a7 commit 77e2b9b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
14 changes: 11 additions & 3 deletions src/otx/core/data/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,14 @@ def __init__( # noqa: PLR0913
for subset_cfg in [train_subset, val_subset, test_subset, unlabeled_subset]:
if subset_cfg.input_size is None:
subset_cfg.input_size = input_size

if self.mem_cache_img_max_size is None:
self.mem_cache_img_max_size = (
(input_size, input_size) # type: ignore[assignment]
if isinstance(input_size, int)
else tuple(input_size)
)

self.input_size = input_size

if self.tile_config.enable_tiler and self.tile_config.enable_adaptive_tiling:
Expand Down Expand Up @@ -193,7 +201,7 @@ def __init__( # noqa: PLR0913
dm_subset=dm_subset.as_dataset(),
cfg_subset=config_mapping[name],
mem_cache_handler=mem_cache_handler,
mem_cache_img_max_size=mem_cache_img_max_size,
mem_cache_img_max_size=self.mem_cache_img_max_size,
image_color_channel=image_color_channel,
stack_images=stack_images,
include_polygons=include_polygons,
Expand Down Expand Up @@ -231,7 +239,7 @@ def __init__( # noqa: PLR0913
dm_subset=dm_subset,
cfg_subset=unlabeled_config,
mem_cache_handler=mem_cache_handler,
mem_cache_img_max_size=mem_cache_img_max_size,
mem_cache_img_max_size=self.mem_cache_img_max_size,
image_color_channel=image_color_channel,
stack_images=stack_images,
include_polygons=include_polygons,
Expand All @@ -245,7 +253,7 @@ def __init__( # noqa: PLR0913
dm_subset=dm_subset.as_dataset(),
cfg_subset=self.unlabeled_subset,
mem_cache_handler=mem_cache_handler,
mem_cache_img_max_size=mem_cache_img_max_size,
mem_cache_img_max_size=self.mem_cache_img_max_size,
image_color_channel=image_color_channel,
stack_images=stack_images,
include_polygons=include_polygons,
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/core/data/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def test_init(
assert fxt_config.train_subset.input_size is None
assert fxt_config.val_subset.input_size is None
assert fxt_config.test_subset.input_size is None
assert module.mem_cache_img_max_size is None

def test_init_input_size(
self,
Expand All @@ -148,7 +149,7 @@ def test_init_input_size(
fxt_config.val_subset.input_size = None
fxt_config.test_subset.input_size = (800, 800)

OTXDataModule(
data_module = OTXDataModule(
task=OTXTaskType.MULTI_CLASS_CLS,
data_format=fxt_config.data_format,
data_root=fxt_config.data_root,
Expand All @@ -161,6 +162,7 @@ def test_init_input_size(
assert fxt_config.train_subset.input_size == (1200, 1200)
assert fxt_config.val_subset.input_size == (1200, 1200)
assert fxt_config.test_subset.input_size == (800, 800)
assert data_module.mem_cache_img_max_size == (1200, 1200)

@pytest.fixture()
def mock_adapt_input_size_to_dataset(self, mocker) -> MagicMock:
Expand Down

0 comments on commit 77e2b9b

Please sign in to comment.