Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3482 Add support for customized ConfigItem and resolver #3980

Open
wants to merge 8 commits into
base: dev
Choose a base branch
from
39 changes: 31 additions & 8 deletions monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import re
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Optional, Sequence, Tuple, Type, Union

from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem
from monai.bundle.reference_resolver import ReferenceResolver
Expand Down Expand Up @@ -76,6 +76,11 @@ class ConfigParser:
The current supported globals and alias names are
``{"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}``.
These are MONAI's minimal dependencies. Additional packages could be included with `globals={"itk": "itk"}`.
item_types: list of supported config item types, must be subclass of `ConfigComponent`,
`ConfigExpression`, `ConfigItem`, will check the types in order for every config item.
if `None`, default to: ``(ConfigComponent, ConfigExpression, ConfigItem)``.
resolver: manage a set of ``ConfigItem`` and resolve the references between them.
if `None`, will create a default `ReferenceResolver` instance.

See also:

Expand All @@ -96,6 +101,8 @@ def __init__(
config: Any = None,
excludes: Optional[Union[Sequence[str], str]] = None,
globals: Optional[Dict[str, Any]] = None,
item_types: Optional[Union[Sequence[Type[ConfigItem]], Type[ConfigItem]]] = None,
resolver: Optional[ReferenceResolver] = None,
):
self.config = None
self.globals: Dict[str, Any] = {}
Expand All @@ -105,9 +112,17 @@ def __init__(
if _globals is not None:
for k, v in _globals.items():
self.globals[k] = optional_import(v)[0] if isinstance(v, str) else v
self.item_types = (
(ConfigComponent, ConfigExpression, ConfigItem) if item_types is None else ensure_tuple(item_types)
)

self.locator = ComponentLocator(excludes=excludes)
self.ref_resolver = ReferenceResolver()
if resolver is not None:
if not isinstance(resolver, ReferenceResolver):
raise TypeError(f"resolver must be subclass of ReferenceResolver, but got: {type(resolver)}.")
self.ref_resolver = resolver
else:
self.ref_resolver = ReferenceResolver()
if config is None:
config = {self.meta_key: {}}
self.set(config=config)
Expand Down Expand Up @@ -309,12 +324,20 @@ def _do_parse(self, config, id: str = ""):

# copy every config item to make them independent and add them to the resolver
item_conf = deepcopy(config)
if ConfigComponent.is_instantiable(item_conf):
self.ref_resolver.add_item(ConfigComponent(config=item_conf, id=id, locator=self.locator))
elif ConfigExpression.is_expression(item_conf):
self.ref_resolver.add_item(ConfigExpression(config=item_conf, id=id, globals=self.globals))
else:
self.ref_resolver.add_item(ConfigItem(config=item_conf, id=id))
for item_type in self.item_types:
if issubclass(item_type, ConfigComponent):
if item_type.is_instantiable(item_conf):
return self.ref_resolver.add_item(item_type(config=item_conf, id=id, locator=self.locator))
continue
if issubclass(item_type, ConfigExpression):
if item_type.is_expression(item_conf):
return self.ref_resolver.add_item(item_type(config=item_conf, id=id, globals=self.globals))
continue
if issubclass(item_type, ConfigItem):
return self.ref_resolver.add_item(item_type(config=item_conf, id=id))
raise TypeError(
f"item type must be subclass of `ConfigComponent`, `ConfigExpression`, `ConfigItem`, got: {item_type}."
)

@classmethod
def load_config_file(cls, filepath: PathLike, **kwargs):
Expand Down
15 changes: 12 additions & 3 deletions tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from parameterized import parameterized

from monai.bundle.config_parser import ConfigParser
from monai.bundle.config_item import ConfigComponent, ConfigExpression, ConfigItem
from monai.bundle.config_parser import ConfigParser, ReferenceResolver
from monai.data import DataLoader, Dataset
from monai.transforms import Compose, LoadImaged, RandTorchVisiond
from monai.utils import min_version, optional_import
Expand Down Expand Up @@ -59,6 +60,10 @@ def __call__(self, a, b):
return self.compute(a, b)


class TestConfigComponent(ConfigComponent):
pass


TEST_CASE_2 = [
{
"basic_func": "$lambda x, y: x + y",
Expand Down Expand Up @@ -106,7 +111,7 @@ def test_config_content(self):
@parameterized.expand([TEST_CASE_1])
@skipUnless(has_tv, "Requires torchvision >= 0.8.0.")
def test_parse(self, config, expected_ids, output_types):
parser = ConfigParser(config=config, globals={"monai": "monai"})
parser = ConfigParser(config=config, globals={"monai": "monai"}, resolver=ReferenceResolver())
# test lazy instantiation with original config content
parser["transform"]["transforms"][0]["keys"] = "label1"
self.assertEqual(parser.get_parsed_content(id="transform#transforms#0").keys[0], "label1")
Expand All @@ -122,7 +127,11 @@ def test_parse(self, config, expected_ids, output_types):

@parameterized.expand([TEST_CASE_2])
def test_function(self, config):
parser = ConfigParser(config=config, globals={"TestClass": TestClass})
parser = ConfigParser(
config=config,
globals={"TestClass": TestClass},
item_types=(TestConfigComponent, ConfigExpression, ConfigItem),
)
for id in config:
func = parser.get_parsed_content(id=id)
self.assertTrue(id in parser.ref_resolver.resolved_content)
Expand Down