Skip to content

Commit

Permalink
3942 3966 adds _requires_ and import support (#3945)
Browse files Browse the repository at this point in the history
* fixes #3942

Signed-off-by: Wenqi Li <[email protected]>

* adds import statement

Signed-off-by: Wenqi Li <[email protected]>

* more tests

Signed-off-by: Wenqi Li <[email protected]>

* update to resolve import statement

Signed-off-by: Wenqi Li <[email protected]>

* update based on comments

Signed-off-by: Wenqi Li <[email protected]>

* update

Signed-off-by: Wenqi Li <[email protected]>

* update based on comments

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Mar 21, 2022
1 parent be25a72 commit ec19406
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 24 deletions.
64 changes: 55 additions & 9 deletions monai/bundle/config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import ast
import inspect
import os
import sys
Expand All @@ -18,7 +19,7 @@
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

from monai.bundle.utils import EXPR_KEY
from monai.utils import ensure_tuple, instantiate
from monai.utils import ensure_tuple, first, instantiate, optional_import

__all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent"]

Expand Down Expand Up @@ -164,17 +165,24 @@ class ConfigComponent(ConfigItem, Instantiable):
Subclass of :py:class:`monai.bundle.ConfigItem`, this class uses a dictionary with string keys to
represent a component of `class` or `function` and supports instantiation.
Currently, two special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals:
Currently, three special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals:
- class or function identifier of the python module, specified by one of the two keys.
- ``"_target_"``: indicates build-in python classes or functions such as "LoadImageDict",
or full module name, such as "monai.transforms.LoadImageDict".
- ``"_disabled_"``: a flag to indicate whether to skip the instantiation.
- class or function identifier of the python module, specified by ``"_target_"``,
indicating a build-in python class or function such as ``"LoadImageDict"``,
or a full module name, such as ``"monai.transforms.LoadImageDict"``.
- ``"_requires_"`` (optional): specifies reference IDs (string starts with ``"@"``) or ``ConfigExpression``
of the dependencies for this ``ConfigComponent`` object. These dependencies will be
evaluated/instantiated before this object is instantiated. It is useful when the
component doesn't explicitly depends on the other `ConfigItems` via its arguments,
but requires the dependencies to be instantiated/evaluated beforehand.
- ``"_disabled_"`` (optional): a flag to indicate whether to skip the instantiation.
Other fields in the config content are input arguments to the python module.
.. code-block:: python
from monai.bundle import ComponentLocator, ConfigComponent
locator = ComponentLocator(excludes=["modules_to_exclude"])
config = {
"_target_": "LoadImaged",
Expand All @@ -195,7 +203,7 @@ class ConfigComponent(ConfigItem, Instantiable):
"""

non_arg_keys = {"_target_", "_disabled_"}
non_arg_keys = {"_target_", "_disabled_", "_requires_"}

def __init__(
self,
Expand Down Expand Up @@ -279,7 +287,7 @@ def instantiate(self, **kwargs) -> object: # type: ignore
class ConfigExpression(ConfigItem):
"""
Subclass of :py:class:`monai.bundle.ConfigItem`, the `ConfigItem` represents an executable expression
(execute based on ``eval()``).
(execute based on ``eval()``, or import the module to the `globals` if it's an import statement).
See also:
Expand Down Expand Up @@ -308,7 +316,26 @@ class ConfigExpression(ConfigItem):

def __init__(self, config: Any, id: str = "", globals: Optional[Dict] = None) -> None:
super().__init__(config=config, id=id)
self.globals = globals
self.globals = globals if globals is not None else {}

def _parse_import_string(self, import_string: str):
"""parse single import statement such as "from monai.transforms import Resize"""
node = first(ast.iter_child_nodes(ast.parse(import_string)))
if not isinstance(node, (ast.Import, ast.ImportFrom)):
return None
if len(node.names) < 1:
return None
if len(node.names) > 1:
warnings.warn(f"ignoring multiple import alias '{import_string}'.")
name, asname = f"{node.names[0].name}", node.names[0].asname
asname = name if asname is None else f"{asname}"
if isinstance(node, ast.ImportFrom):
self.globals[asname], _ = optional_import(f"{node.module}", name=f"{name}")
return self.globals[asname]
if isinstance(node, ast.Import):
self.globals[asname], _ = optional_import(f"{name}")
return self.globals[asname]
return None

def evaluate(self, locals: Optional[Dict] = None):
"""
Expand All @@ -322,6 +349,9 @@ def evaluate(self, locals: Optional[Dict] = None):
value = self.get_config()
if not ConfigExpression.is_expression(value):
return None
optional_module = self._parse_import_string(value[len(self.prefix) :])
if optional_module is not None:
return optional_module
if not self.run_eval:
return f"{value[len(self.prefix) :]}"
return eval(value[len(self.prefix) :], self.globals, locals)
Expand All @@ -337,3 +367,19 @@ def is_expression(cls, config: Union[Dict, List, str]) -> bool:
"""
return isinstance(config, str) and config.startswith(cls.prefix)

@classmethod
def is_import_statement(cls, config: Union[Dict, List, str]) -> bool:
"""
Check whether the config is an import statement (a special case of expression).
Args:
config: input config content to check.
"""
if not cls.is_expression(config):
return False
if "import" not in config:
return False
return isinstance(
first(ast.iter_child_nodes(ast.parse(f"{config[len(cls.prefix) :]}"))), (ast.Import, ast.ImportFrom)
)
15 changes: 9 additions & 6 deletions monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import json
import re
from copy import deepcopy
Expand All @@ -26,6 +25,8 @@

__all__ = ["ConfigParser"]

_default_globals = {"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}


class ConfigParser:
"""
Expand Down Expand Up @@ -74,7 +75,7 @@ class ConfigParser:
so that expressions, for example, ``"$monai.data.list_data_collate"`` can use ``monai`` modules.
The current supported globals and alias names are
``{"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}``.
These are MONAI's minimal dependencies.
These are MONAI's minimal dependencies. Additional packages could be included with `globals={"itk": "itk"}`.
See also:
Expand All @@ -96,10 +97,12 @@ def __init__(
):
self.config = None
self.globals: Dict[str, Any] = {}
globals = {"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"} if globals is None else globals
if globals is not None:
for k, v in globals.items():
self.globals[k] = importlib.import_module(v) if isinstance(v, str) else v
_globals = _default_globals.copy()
if isinstance(_globals, dict) and globals is not None:
_globals.update(globals)
if _globals is not None:
for k, v in _globals.items():
self.globals[k] = optional_import(v)[0] if isinstance(v, str) else v

self.locator = ComponentLocator(excludes=excludes)
self.ref_resolver = ReferenceResolver()
Expand Down
21 changes: 14 additions & 7 deletions monai/bundle/reference_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def get_item(self, id: str, resolve: bool = False, **kwargs):

def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **kwargs):
"""
Resolve one ``ConfigItem`` of ``id``, cache the resolved result in ``resolved_content``.
Resolve and return one ``ConfigItem`` of ``id``, cache the resolved result in ``resolved_content``.
If it has unresolved references, recursively resolve the referring items first.
Args:
Expand All @@ -111,6 +111,8 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **
Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True.
"""
if id in self.resolved_content:
return self.resolved_content[id]
try:
item = look_up_option(id, self.items, print_all_options=False)
except ValueError as err:
Expand All @@ -121,8 +123,14 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **
waiting_list = set()
waiting_list.add(id)

ref_ids = self.find_refs_in_config(config=item_config, id=id)
for d in ref_ids:
for t, v in self.items.items():
if (
t not in self.resolved_content
and isinstance(v, ConfigExpression)
and v.is_import_statement(v.get_config())
):
self.resolved_content[t] = v.evaluate() if kwargs.get("eval_expr", True) else v
for d in self.find_refs_in_config(config=item_config, id=id):
# if current item has reference already in the waiting list, that's circular references
if d in waiting_list:
raise ValueError(f"detected circular references '{d}' for id='{id}' in the config content.")
Expand Down Expand Up @@ -150,20 +158,19 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **
)
else:
self.resolved_content[id] = new_config
return self.resolved_content[id]

def get_resolved_content(self, id: str, **kwargs):
"""
Get the resolved ``ConfigItem`` by id. If there are unresolved references, try to resolve them first.
Get the resolved ``ConfigItem`` by id.
Args:
id: id name of the expected item.
kwargs: additional keyword arguments to be passed to ``_resolve_one_item``.
Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True.
"""
if id not in self.resolved_content:
self._resolve_one_item(id=id, **kwargs)
return self.resolved_content[id]
return self._resolve_one_item(id=id, **kwargs)

@classmethod
def match_refs_pattern(cls, value: str) -> Set[str]:
Expand Down
2 changes: 1 addition & 1 deletion monai/bundle/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def verify_net_in_out(
if it is a list of file paths, the content of them will be merged.
device: target device to run the network forward computation, if None, prefer to "cuda" if existing.
p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1.
p: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1.
n: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1.
any: specified size to generate fake data shape if dim of expected shape is "*", default to 1.
args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`,
`net_id` and override pairs. so that the command line inputs can be simplified.
Expand Down
5 changes: 4 additions & 1 deletion tests/test_bundle_verify_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ def test_verify(self, meta_file, config_file):
cmd = [sys.executable, "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file", meta_file]
cmd += ["--config_file", config_file, "-n", "2", "--any", "32", "--args_file", def_args_file]
cmd += ["--_meta_#network_data_format#inputs#image#spatial_shape", "[32,'*','4**p*n']"]
ret = subprocess.check_call(cmd)

test_env = os.environ.copy()
print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES"))
ret = subprocess.check_call(cmd, env=test_env)
self.assertEqual(ret, 0)


Expand Down
28 changes: 28 additions & 0 deletions tests/test_config_item.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,34 @@ def test_lazy_instantiation(self):
self.assertTrue(isinstance(ret, DataLoader))
self.assertEqual(ret.batch_size, 4)

@parameterized.expand([("$import json", "json"), ("$import json as j", "j")])
def test_import(self, stmt, mod_name):
test_globals = {}
ConfigExpression(id="", config=stmt, globals=test_globals).evaluate()
self.assertTrue(callable(test_globals[mod_name].dump))

@parameterized.expand(
[
("$from json import dump", "dump"),
("$from json import dump, dumps", "dump"),
("$from json import dump as jd", "jd"),
("$from json import dump as jd, dumps as ds", "jd"),
]
)
def test_import_from(self, stmt, mod_name):
test_globals = {}
ConfigExpression(id="", config=stmt, globals=test_globals).evaluate()
self.assertTrue(callable(test_globals[mod_name]))
self.assertTrue(ConfigExpression.is_import_statement(ConfigExpression(id="", config=stmt).config))

@parameterized.expand(
[("$from json import dump", True), ("$print()", False), ("$import json", True), ("import json", False)]
)
def test_is_import_stmt(self, stmt, expected):
expr = ConfigExpression(id="", config=stmt)
flag = expr.is_import_statement(expr.config)
self.assertEqual(flag, expected)


if __name__ == "__main__":
unittest.main()
10 changes: 10 additions & 0 deletions tests/testing_data/inference.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
{
"dataset_dir": "/workspace/data/Task09_Spleen",
"import_glob": "$import glob",
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"set_seed": "$monai.utils.set_determinism(0)",
"print_test_name": "$print('json_test')",
"print_glob_file": "$print(glob.__file__)",
"network_def": {
"_target_": "UNet",
"spatial_dims": 3,
Expand Down Expand Up @@ -94,6 +98,12 @@
},
"evaluator": {
"_target_": "SupervisedEvaluator",
"_requires_": [
"@set_seed",
"@print_test_name",
"@print_glob_file",
"$print('test_in_line_json')"
],
"device": "@device",
"val_data_loader": "@dataloader",
"network": "@network",
Expand Down
6 changes: 6 additions & 0 deletions tests/testing_data/inference.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
---
dataset_dir: "/workspace/data/Task09_Spleen"
device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')"
set_seed: "$monai.utils.set_determinism(0)"
print_test_name: "$print('yaml_test')"
network_def:
_target_: UNet
spatial_dims: 3
Expand Down Expand Up @@ -67,6 +69,10 @@ postprocessing:
output_dir: "@_meta_#output_dir"
evaluator:
_target_: SupervisedEvaluator
_requires_:
- "$print('test_in_line_yaml')"
- "@set_seed"
- "@print_test_name"
device: "@device"
val_data_loader: "@dataloader"
network: "@network"
Expand Down

0 comments on commit ec19406

Please sign in to comment.