From 77e2b9ba165485483d257392c9ff78c535115ae8 Mon Sep 17 00:00:00 2001 From: "Shin, Eunwoo" Date: Wed, 14 Aug 2024 17:25:30 +0900 Subject: [PATCH] update update_mem_cache_img_max_size to input_size if it's none --- src/otx/core/data/module.py | 14 +++++++++++--- tests/unit/core/data/test_module.py | 4 +++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/otx/core/data/module.py b/src/otx/core/data/module.py index 06f62f1c614..c5966d51a0a 100644 --- a/src/otx/core/data/module.py +++ b/src/otx/core/data/module.py @@ -146,6 +146,14 @@ def __init__( # noqa: PLR0913 for subset_cfg in [train_subset, val_subset, test_subset, unlabeled_subset]: if subset_cfg.input_size is None: subset_cfg.input_size = input_size + + if self.mem_cache_img_max_size is None: + self.mem_cache_img_max_size = ( + (input_size, input_size) # type: ignore[assignment] + if isinstance(input_size, int) + else tuple(input_size) + ) + self.input_size = input_size if self.tile_config.enable_tiler and self.tile_config.enable_adaptive_tiling: @@ -193,7 +201,7 @@ def __init__( # noqa: PLR0913 dm_subset=dm_subset.as_dataset(), cfg_subset=config_mapping[name], mem_cache_handler=mem_cache_handler, - mem_cache_img_max_size=mem_cache_img_max_size, + mem_cache_img_max_size=self.mem_cache_img_max_size, image_color_channel=image_color_channel, stack_images=stack_images, include_polygons=include_polygons, @@ -231,7 +239,7 @@ def __init__( # noqa: PLR0913 dm_subset=dm_subset, cfg_subset=unlabeled_config, mem_cache_handler=mem_cache_handler, - mem_cache_img_max_size=mem_cache_img_max_size, + mem_cache_img_max_size=self.mem_cache_img_max_size, image_color_channel=image_color_channel, stack_images=stack_images, include_polygons=include_polygons, @@ -245,7 +253,7 @@ def __init__( # noqa: PLR0913 dm_subset=dm_subset.as_dataset(), cfg_subset=self.unlabeled_subset, mem_cache_handler=mem_cache_handler, - mem_cache_img_max_size=mem_cache_img_max_size, + mem_cache_img_max_size=self.mem_cache_img_max_size, image_color_channel=image_color_channel, stack_images=stack_images, include_polygons=include_polygons, diff --git a/tests/unit/core/data/test_module.py b/tests/unit/core/data/test_module.py index e5365406ddc..77a364504d1 100644 --- a/tests/unit/core/data/test_module.py +++ b/tests/unit/core/data/test_module.py @@ -133,6 +133,7 @@ def test_init( assert fxt_config.train_subset.input_size is None assert fxt_config.val_subset.input_size is None assert fxt_config.test_subset.input_size is None + assert module.mem_cache_img_max_size is None def test_init_input_size( self, @@ -148,7 +149,7 @@ def test_init_input_size( fxt_config.val_subset.input_size = None fxt_config.test_subset.input_size = (800, 800) - OTXDataModule( + data_module = OTXDataModule( task=OTXTaskType.MULTI_CLASS_CLS, data_format=fxt_config.data_format, data_root=fxt_config.data_root, @@ -161,6 +162,7 @@ def test_init_input_size( assert fxt_config.train_subset.input_size == (1200, 1200) assert fxt_config.val_subset.input_size == (1200, 1200) assert fxt_config.test_subset.input_size == (800, 800) + assert data_module.mem_cache_img_max_size == (1200, 1200) @pytest.fixture() def mock_adapt_input_size_to_dataset(self, mocker) -> MagicMock: