From 59b65df46898c7c0dcd04f08c6568a57b5c4f061 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 22 Mar 2022 18:51:05 +0800 Subject: [PATCH] [DLMED] add customized configitem and resolver Signed-off-by: Nic Ma --- monai/bundle/config_parser.py | 39 ++++++++++++++++++++++++++++------- tests/test_config_parser.py | 17 +++++++++++---- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index 23d4ac7c55..ff561d6d15 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -13,7 +13,7 @@ import re from copy import deepcopy from pathlib import Path -from typing import Any, Dict, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem from monai.bundle.reference_resolver import ReferenceResolver @@ -76,6 +76,11 @@ class ConfigParser: The current supported globals and alias names are ``{"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}``. These are MONAI's minimal dependencies. Additional packages could be included with `globals={"itk": "itk"}`. + item_types: list of supported config item types, must be subclass of `ConfigComponent`, + `ConfigExpression`, `ConfigItem`, will check the types in order for every config item. + if `None`, default to: ``(ConfigComponent, ConfigExpression, ConfigItem)``. + resolver: manage a set of ``ConfigItem`` and resolve the references between them. + if `None`, will create a default `ReferenceResolver` instance. See also: @@ -94,6 +99,8 @@ def __init__( config: Any = None, excludes: Optional[Union[Sequence[str], str]] = None, globals: Optional[Dict[str, Any]] = None, + item_types: Optional[Union[Sequence[Type[ConfigItem]], Type[ConfigItem]]] = None, + resolver: Optional[ReferenceResolver] = None, ): self.config = None self.globals: Dict[str, Any] = {} @@ -103,9 +110,17 @@ def __init__( if _globals is not None: for k, v in _globals.items(): self.globals[k] = optional_import(v)[0] if isinstance(v, str) else v + self.item_types = ( + (ConfigComponent, ConfigExpression, ConfigItem) if item_types is None else ensure_tuple(item_types) + ) self.locator = ComponentLocator(excludes=excludes) - self.ref_resolver = ReferenceResolver() + if resolver is not None: + if not isinstance(resolver, ReferenceResolver): + raise TypeError(f"resolver must be subclass of ReferenceResolver, but got: {type(resolver)}.") + self.ref_resolver = resolver + else: + self.ref_resolver = ReferenceResolver() if config is None: config = {self.meta_key: {}} self.set(config=config) @@ -295,12 +310,20 @@ def _do_parse(self, config, id: str = ""): # copy every config item to make them independent and add them to the resolver item_conf = deepcopy(config) - if ConfigComponent.is_instantiable(item_conf): - self.ref_resolver.add_item(ConfigComponent(config=item_conf, id=id, locator=self.locator)) - elif ConfigExpression.is_expression(item_conf): - self.ref_resolver.add_item(ConfigExpression(config=item_conf, id=id, globals=self.globals)) - else: - self.ref_resolver.add_item(ConfigItem(config=item_conf, id=id)) + for item_type in self.item_types: + if issubclass(item_type, ConfigComponent): + if item_type.is_instantiable(item_conf): + return self.ref_resolver.add_item(item_type(config=item_conf, id=id, locator=self.locator)) + continue + if issubclass(item_type, ConfigExpression): + if item_type.is_expression(item_conf): + return self.ref_resolver.add_item(item_type(config=item_conf, id=id, globals=self.globals)) + continue + if issubclass(item_type, ConfigItem): + return self.ref_resolver.add_item(item_type(config=item_conf, id=id)) + raise TypeError( + f"item type must be subclass of `ConfigComponent`, `ConfigExpression`, `ConfigItem`, got: {item_type}." + ) @classmethod def load_config_file(cls, filepath: PathLike, **kwargs): diff --git a/tests/test_config_parser.py b/tests/test_config_parser.py index ce98be1214..e85f13b6c6 100644 --- a/tests/test_config_parser.py +++ b/tests/test_config_parser.py @@ -14,7 +14,8 @@ from parameterized import parameterized -from monai.bundle.config_parser import ConfigParser +from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem +from monai.bundle.config_parser import ConfigParser, ReferenceResolver from monai.data import DataLoader, Dataset from monai.transforms import Compose, LoadImaged, RandTorchVisiond from monai.utils import min_version, optional_import @@ -57,6 +58,10 @@ def __call__(self, a, b): return self.compute(a, b) +class TestConfigComponent(ConfigComponent): + pass + + TEST_CASE_2 = [ { "basic_func": "$lambda x, y: x + y", @@ -73,7 +78,7 @@ def __call__(self, a, b): ] -class TestConfigComponent(unittest.TestCase): +class TestConfigParser(unittest.TestCase): def test_config_content(self): test_config = {"preprocessing": [{"_target_": "LoadImage"}], "dataset": {"_target_": "Dataset"}} parser = ConfigParser(config=test_config) @@ -94,7 +99,7 @@ def test_config_content(self): @parameterized.expand([TEST_CASE_1]) @skipUnless(has_tv, "Requires torchvision >= 0.8.0.") def test_parse(self, config, expected_ids, output_types): - parser = ConfigParser(config=config, globals={"monai": "monai"}) + parser = ConfigParser(config=config, globals={"monai": "monai"}, resolver=ReferenceResolver()) # test lazy instantiation with original config content parser["transform"]["transforms"][0]["keys"] = "label1" self.assertEqual(parser.get_parsed_content(id="transform#transforms#0").keys[0], "label1") @@ -110,7 +115,11 @@ def test_parse(self, config, expected_ids, output_types): @parameterized.expand([TEST_CASE_2]) def test_function(self, config): - parser = ConfigParser(config=config, globals={"TestClass": TestClass}) + parser = ConfigParser( + config=config, + globals={"TestClass": TestClass}, + item_types=(TestConfigComponent, ConfigExpression, ConfigItem), + ) for id in config: func = parser.get_parsed_content(id=id) self.assertTrue(id in parser.ref_resolver.resolved_content)