Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

3942 3966 adds _requires_ and import support #3945

Merged
merged 12 commits into from
Mar 21, 2022
66 changes: 56 additions & 10 deletions monai/bundle/config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import inspect
import os
import sys
Expand All @@ -18,7 +19,7 @@
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

from monai.bundle.utils import EXPR_KEY
from monai.utils import ensure_tuple, instantiate
from monai.utils import ensure_tuple, first, instantiate, optional_import

__all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent"]

Expand Down Expand Up @@ -164,17 +165,24 @@ class ConfigComponent(ConfigItem, Instantiable):
Subclass of :py:class:`monai.bundle.ConfigItem`, this class uses a dictionary with string keys to
represent a component of `class` or `function` and supports instantiation.

Currently, two special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals:
Currently, three special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals:

- class or function identifier of the python module, specified by one of the two keys.
- ``"_target_"``: indicates build-in python classes or functions such as "LoadImageDict",
or full module name, such as "monai.transforms.LoadImageDict".
- ``"_disabled_"``: a flag to indicate whether to skip the instantiation.
- class or function identifier of the python module, specified by ``"_target_"``,
indicating a build-in python class or function such as ``"LoadImageDict"``,
or a full module name, such as ``"monai.transforms.LoadImageDict"``.
- ``"_requires_"`` (optional): specifies reference IDs (string starts with ``"@"``) or ``ConfigExpression``
of the dependencies for this ``ConfigComponent`` object. These dependencies will be
evaluated/instantiated before this object is instantiated.
- ``"_disabled_"`` (optional): a flag to indicate whether to skip the instantiation.

Other fields in the config content are input arguments to the python module.
Other fields in the config content are input arguments to the python module. ``"_requires_"`` is only useful when the
component doesn't explicitly depends on the other `ConfigItems` via its arguments,
but requires the dependencies to be instantiated/evaluated beforehand.

.. code-block:: python

from monai.bundle import ComponentLocator, ConfigComponent

locator = ComponentLocator(excludes=["modules_to_exclude"])
config = {
"_target_": "LoadImaged",
Expand All @@ -195,7 +203,7 @@ class ConfigComponent(ConfigItem, Instantiable):

"""

non_arg_keys = {"_target_", "_disabled_"}
non_arg_keys = {"_target_", "_disabled_", "_requires_"}

def __init__(
self,
Expand Down Expand Up @@ -279,7 +287,7 @@ def instantiate(self, **kwargs) -> object: # type: ignore
class ConfigExpression(ConfigItem):
"""
Subclass of :py:class:`monai.bundle.ConfigItem`, the `ConfigItem` represents an executable expression
(execute based on ``eval()``).
(execute based on ``eval()``, or import the module to the `globals` if it's an import statement).

See also:

Expand Down Expand Up @@ -308,7 +316,26 @@ class ConfigExpression(ConfigItem):

def __init__(self, config: Any, id: str = "", globals: Optional[Dict] = None) -> None:
super().__init__(config=config, id=id)
self.globals = globals
self.globals = globals if globals is not None else {}

def _parse_import_string(self, import_string: str):
"""parse single import statement such as "from monai.transforms import Resize"""
node = first(ast.iter_child_nodes(ast.parse(import_string)))
if not isinstance(node, (ast.Import, ast.ImportFrom)):
return None
if len(node.names) < 1:
return None
if len(node.names) > 1:
warnings.warn(f"ignoring multiple import alias '{import_string}'.")
name, asname = f"{node.names[0].name}", node.names[0].asname
asname = name if asname is None else f"{asname}"
if isinstance(node, ast.ImportFrom):
self.globals[asname], _ = optional_import(f"{node.module}", name=f"{name}")
return self.globals[asname]
if isinstance(node, ast.Import):
self.globals[asname], _ = optional_import(f"{name}")
return self.globals[asname]
return None

def evaluate(self, locals: Optional[Dict] = None):
"""
Expand All @@ -322,6 +349,9 @@ def evaluate(self, locals: Optional[Dict] = None):
value = self.get_config()
if not ConfigExpression.is_expression(value):
return None
optional_module = self._parse_import_string(value[len(self.prefix) :])
if optional_module is not None:
return optional_module
if not self.run_eval:
return f"{value[len(self.prefix) :]}"
return eval(value[len(self.prefix) :], self.globals, locals)
Expand All @@ -337,3 +367,19 @@ def is_expression(cls, config: Union[Dict, List, str]) -> bool:

"""
return isinstance(config, str) and config.startswith(cls.prefix)

@classmethod
def is_import_statement(cls, config: Union[Dict, List, str]) -> bool:
"""
Check whether the config is an import statement (a special case of expression).

Args:
config: input config content to check.
"""
if not cls.is_expression(config):
return False
if "import" not in config:
return False
return isinstance(
first(ast.iter_child_nodes(ast.parse(f"{config[len(cls.prefix) :]}"))), (ast.Import, ast.ImportFrom)
)
15 changes: 9 additions & 6 deletions monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import json
import re
from copy import deepcopy
Expand All @@ -26,6 +25,8 @@

__all__ = ["ConfigParser"]

_default_globals = {"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}


class ConfigParser:
"""
Expand Down Expand Up @@ -74,7 +75,7 @@ class ConfigParser:
so that expressions, for example, ``"$monai.data.list_data_collate"`` can use ``monai`` modules.
The current supported globals and alias names are
``{"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}``.
These are MONAI's minimal dependencies.
These are MONAI's minimal dependencies. Additional packages could be included with `globals={"itk": "itk"}`.

See also:

Expand All @@ -96,10 +97,12 @@ def __init__(
):
self.config = None
self.globals: Dict[str, Any] = {}
globals = {"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"} if globals is None else globals
if globals is not None:
for k, v in globals.items():
self.globals[k] = importlib.import_module(v) if isinstance(v, str) else v
_globals = _default_globals.copy()
if isinstance(_globals, dict) and globals is not None:
_globals.update(globals)
if _globals is not None:
for k, v in _globals.items():
self.globals[k] = optional_import(v)[0] if isinstance(v, str) else v

self.locator = ComponentLocator(excludes=excludes)
self.ref_resolver = ReferenceResolver()
Expand Down
21 changes: 14 additions & 7 deletions monai/bundle/reference_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def get_item(self, id: str, resolve: bool = False, **kwargs):

def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **kwargs):
"""
Resolve one ``ConfigItem`` of ``id``, cache the resolved result in ``resolved_content``.
Resolve and return one ``ConfigItem`` of ``id``, cache the resolved result in ``resolved_content``.
If it has unresolved references, recursively resolve the referring items first.

Args:
Expand All @@ -111,6 +111,8 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **
Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True.

"""
if id in self.resolved_content:
return self.resolved_content[id]
try:
item = look_up_option(id, self.items, print_all_options=False)
except ValueError as err:
Expand All @@ -121,8 +123,14 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **
waiting_list = set()
waiting_list.add(id)

ref_ids = self.find_refs_in_config(config=item_config, id=id)
for d in ref_ids:
for t, v in self.items.items():
if (
t not in self.resolved_content
and isinstance(v, ConfigExpression)
and v.is_import_statement(v.get_config())
):
self.resolved_content[t] = v.evaluate() if kwargs.get("eval_expr", True) else v
for d in self.find_refs_in_config(config=item_config, id=id):
# if current item has reference already in the waiting list, that's circular references
if d in waiting_list:
raise ValueError(f"detected circular references '{d}' for id='{id}' in the config content.")
Expand Down Expand Up @@ -150,20 +158,19 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **
)
else:
self.resolved_content[id] = new_config
return self.resolved_content[id]

def get_resolved_content(self, id: str, **kwargs):
"""
Get the resolved ``ConfigItem`` by id. If there are unresolved references, try to resolve them first.
Get the resolved ``ConfigItem`` by id.

Args:
id: id name of the expected item.
kwargs: additional keyword arguments to be passed to ``_resolve_one_item``.
Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True.

"""
if id not in self.resolved_content:
self._resolve_one_item(id=id, **kwargs)
return self.resolved_content[id]
return self._resolve_one_item(id=id, **kwargs)

@classmethod
def match_refs_pattern(cls, value: str) -> Set[str]:
Expand Down
19 changes: 19 additions & 0 deletions tests/test_config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,25 @@ def test_lazy_instantiation(self):
self.assertTrue(isinstance(ret, DataLoader))
self.assertEqual(ret.batch_size, 4)

@parameterized.expand([("$import json", "json"), ("$import json as j", "j")])
def test_import(self, stmt, mod_name):
test_globals = {}
ConfigExpression(id="", config=stmt, globals=test_globals).evaluate()
self.assertTrue(callable(test_globals[mod_name].dump))

@parameterized.expand(
[
("$from json import dump", "dump"),
("$from json import dump, dumps", "dump"),
("$from json import dump as jd", "jd"),
("$from json import dump as jd, dumps as ds", "jd"),
]
)
def test_import_from(self, stmt, mod_name):
test_globals = {}
ConfigExpression(id="", config=stmt, globals=test_globals).evaluate()
self.assertTrue(callable(test_globals[mod_name]))


if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions tests/testing_data/inference.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
{
"dataset_dir": "/workspace/data/Task09_Spleen",
"import_glob": "$import glob",
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"set_seed": "$monai.utils.set_determinism(0)",
"print_test_name": "$print('json_test')",
"print_glob_file": "$print(glob.__file__)",
"network_def": {
"_target_": "UNet",
"spatial_dims": 3,
Expand Down Expand Up @@ -94,6 +98,12 @@
},
"evaluator": {
"_target_": "SupervisedEvaluator",
"_requires_": [
"@set_seed",
"@print_test_name",
"@print_glob_file",
"$print('test_in_line_json')"
],
"device": "@device",
"val_data_loader": "@dataloader",
"network": "@network",
Expand Down
6 changes: 6 additions & 0 deletions tests/testing_data/inference.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
---
dataset_dir: "/workspace/data/Task09_Spleen"
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
set_seed: "$monai.utils.set_determinism(0)"
print_test_name: "$print('yaml_test')"
network_def:
_target_: UNet
spatial_dims: 3
Expand Down Expand Up @@ -67,6 +69,10 @@ postprocessing:
output_dir: "@_meta_#output_dir"
evaluator:
_target_: SupervisedEvaluator
_requires_:
- "$print('test_in_line_yaml')"
- "@set_seed"
- "@print_test_name"
device: "@device"
val_data_loader: "@dataloader"
network: "@network"
Expand Down