From 29243d2446265fcc5cc431c4a73edede8689949b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 15 Mar 2022 11:59:47 +0000 Subject: [PATCH 1/8] fixes #3942 Signed-off-by: Wenqi Li --- monai/bundle/config_item.py | 17 +++++++++++------ monai/bundle/config_parser.py | 12 ++++++++---- tests/testing_data/inference.json | 7 +++++++ tests/testing_data/inference.yaml | 6 ++++++ 4 files changed, 32 insertions(+), 10 deletions(-) diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py index 807b369f5d..90e07b902b 100644 --- a/monai/bundle/config_item.py +++ b/monai/bundle/config_item.py @@ -164,17 +164,22 @@ class ConfigComponent(ConfigItem, Instantiable): Subclass of :py:class:`monai.bundle.ConfigItem`, this class uses a dictionary with string keys to represent a component of `class` or `function` and supports instantiation. - Currently, two special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals: - - - class or function identifier of the python module, specified by one of the two keys. - - ``"_target_"``: indicates build-in python classes or functions such as "LoadImageDict", - or full module name, such as "monai.transforms.LoadImageDict". + Currently, three special keys (strings surrounded by ``_``) are defined and interpreted beyond the regular literals: + + - class or function identifier of the python module, specified by ``"_target_"``, + indicating a build-in python class or function such as ``"LoadImageDict"``, + or a full module name, such as ``"monai.transforms.LoadImageDict"``. + - ``"_requires_"``: specifies reference IDs (string starts with ``"@"``) or ``ConfigExpression`` + of the dependencies for this ``ConfigComponent`` object. These dependencies will be + evaluated/instantiated before this object is instantiated. - ``"_disabled_"``: a flag to indicate whether to skip the instantiation. Other fields in the config content are input arguments to the python module. .. code-block:: python + from monai.bundle import ComponentLocator, ConfigComponent + locator = ComponentLocator(excludes=["modules_to_exclude"]) config = { "_target_": "LoadImaged", @@ -195,7 +200,7 @@ class ConfigComponent(ConfigItem, Instantiable): """ - non_arg_keys = {"_target_", "_disabled_"} + non_arg_keys = {"_target_", "_disabled_", "_requires_"} def __init__( self, diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index 6fa7b3a2a2..9a0325e9ab 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -26,6 +26,8 @@ __all__ = ["ConfigParser"] +_default_globals = {"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"} + class ConfigParser: """ @@ -74,7 +76,7 @@ class ConfigParser: so that expressions, for example, ``"$monai.data.list_data_collate"`` can use ``monai`` modules. The current supported globals and alias names are ``{"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"}``. - These are MONAI's minimal dependencies. + These are MONAI's minimal dependencies. Additional packages could be included with `globals={"itk": "itk"}`. See also: @@ -96,9 +98,11 @@ def __init__( ): self.config = None self.globals: Dict[str, Any] = {} - globals = {"monai": "monai", "torch": "torch", "np": "numpy", "numpy": "numpy"} if globals is None else globals - if globals is not None: - for k, v in globals.items(): + _globals = _default_globals.copy() + if isinstance(_globals, dict) and globals is not None: + _globals.update(globals) + if _globals is not None: + for k, v in _globals.items(): self.globals[k] = importlib.import_module(v) if isinstance(v, str) else v self.locator = ComponentLocator(excludes=excludes) diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json index 6cc6de88ef..72e246dc98 100644 --- a/tests/testing_data/inference.json +++ b/tests/testing_data/inference.json @@ -1,5 +1,7 @@ { "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", + "set_seed": "$monai.utils.set_determinism(0)", + "print_test_name": "$print('json_test')", "network_def": { "_target_": "UNet", "spatial_dims": 3, @@ -93,6 +95,11 @@ }, "evaluator": { "_target_": "SupervisedEvaluator", + "_requires_": [ + "@set_seed", + "@print_test_name", + "$print('test_in_line_json')" + ], "device": "@device", "val_data_loader": "@dataloader", "network": "@network", diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml index eb2870ee03..fcb2376c7f 100644 --- a/tests/testing_data/inference.yaml +++ b/tests/testing_data/inference.yaml @@ -1,5 +1,7 @@ --- device: "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')" +set_seed: "$monai.utils.set_determinism(0)" +print_test_name: "$print('yaml_test')" network_def: _target_: UNet spatial_dims: 3 @@ -66,6 +68,10 @@ postprocessing: output_dir: "@_meta_#output_dir" evaluator: _target_: SupervisedEvaluator + _requires_: + - "$print('test_in_line_yaml')" + - "@set_seed" + - "@print_test_name" device: "@device" val_data_loader: "@dataloader" network: "@network" From a1df0c4e847f22bb719a1a0240c24a7033de271b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 17 Mar 2022 14:44:42 +0000 Subject: [PATCH 2/8] adds import statement Signed-off-by: Wenqi Li --- monai/bundle/config_item.py | 29 ++++++++++++++++++++++++++--- monai/bundle/config_parser.py | 3 +-- tests/test_config_item.py | 6 ++++++ tests/testing_data/inference.json | 3 +++ 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py index 90e07b902b..47c1980897 100644 --- a/monai/bundle/config_item.py +++ b/monai/bundle/config_item.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ast import inspect import os import sys @@ -18,7 +19,7 @@ from typing import Any, Dict, List, Mapping, Optional, Sequence, Union from monai.bundle.utils import EXPR_KEY -from monai.utils import ensure_tuple, instantiate +from monai.utils import ensure_tuple, instantiate, optional_import __all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent"] @@ -284,7 +285,7 @@ def instantiate(self, **kwargs) -> object: # type: ignore class ConfigExpression(ConfigItem): """ Subclass of :py:class:`monai.bundle.ConfigItem`, the `ConfigItem` represents an executable expression - (execute based on ``eval()``). + (execute based on ``eval()``, or import the module to the `globals` if it's an import statement). See also: @@ -313,7 +314,26 @@ class ConfigExpression(ConfigItem): def __init__(self, config: Any, id: str = "", globals: Optional[Dict] = None) -> None: super().__init__(config=config, id=id) - self.globals = globals + self.globals = globals if globals is not None else {} + + def _parse_import_string(self, import_string: str): + # parse single import statement such as "from monai.transforms import Resize" + for n in ast.iter_child_nodes(ast.parse(import_string)): + if not isinstance(n, (ast.Import, ast.ImportFrom)): + return None + if len(n.names) < 1: + return None + if len(n.names) > 1: + warnings.warn(f"ignoring multiple import alias '{import_string}'.") + name, asname = f"{n.names[0].name}", n.names[0].asname + asname = name if asname is None else f"{asname}" + if isinstance(n, ast.ImportFrom): + self.globals[asname], _ = optional_import(f"{n.module}", name=f"{name}") + return self.globals[asname] + elif isinstance(n, ast.Import): + self.globals[asname], _ = optional_import(f"{name}") + return self.globals[asname] + return None def evaluate(self, locals: Optional[Dict] = None): """ @@ -327,6 +347,9 @@ def evaluate(self, locals: Optional[Dict] = None): value = self.get_config() if not ConfigExpression.is_expression(value): return None + optional_module = self._parse_import_string(value[len(self.prefix) :]) + if optional_module is not None: + return optional_module if not self.run_eval: return f"{value[len(self.prefix) :]}" return eval(value[len(self.prefix) :], self.globals, locals) diff --git a/monai/bundle/config_parser.py b/monai/bundle/config_parser.py index 9a0325e9ab..23d4ac7c55 100644 --- a/monai/bundle/config_parser.py +++ b/monai/bundle/config_parser.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import importlib import json import re from copy import deepcopy @@ -103,7 +102,7 @@ def __init__( _globals.update(globals) if _globals is not None: for k, v in _globals.items(): - self.globals[k] = importlib.import_module(v) if isinstance(v, str) else v + self.globals[k] = optional_import(v)[0] if isinstance(v, str) else v self.locator = ComponentLocator(excludes=excludes) self.ref_resolver = ReferenceResolver() diff --git a/tests/test_config_item.py b/tests/test_config_item.py index 7b43cd30ea..48be60e6f4 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -91,6 +91,12 @@ def test_lazy_instantiation(self): self.assertTrue(isinstance(ret, DataLoader)) self.assertEqual(ret.batch_size, 4) + def test_import(self): + import_string = "$import json" + test_globals = {} + ConfigExpression(id="", config=import_string, globals=test_globals).evaluate() + self.assertTrue(callable(test_globals["json"].dump)) + if __name__ == "__main__": unittest.main() diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json index 932b9018e1..4c29a4e55b 100644 --- a/tests/testing_data/inference.json +++ b/tests/testing_data/inference.json @@ -1,7 +1,9 @@ { + "import_glob": "$import glob", "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "set_seed": "$monai.utils.set_determinism(0)", "print_test_name": "$print('json_test')", + "print_glob_file": "$print(@import_glob.__file__)", "network_def": { "_target_": "UNet", "spatial_dims": 3, @@ -98,6 +100,7 @@ "_requires_": [ "@set_seed", "@print_test_name", + "@print_glob_file", "$print('test_in_line_json')" ], "device": "@device", From 8f1edacae122d23c950177194ad4e04e0c358b8b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 17 Mar 2022 15:02:17 +0000 Subject: [PATCH 3/8] more tests Signed-off-by: Wenqi Li --- tests/test_config_item.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/test_config_item.py b/tests/test_config_item.py index 48be60e6f4..83e214e369 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -91,11 +91,24 @@ def test_lazy_instantiation(self): self.assertTrue(isinstance(ret, DataLoader)) self.assertEqual(ret.batch_size, 4) - def test_import(self): - import_string = "$import json" + @parameterized.expand([("$import json", "json"), ("$import json as j", "j")]) + def test_import(self, stmt, mod_name): test_globals = {} - ConfigExpression(id="", config=import_string, globals=test_globals).evaluate() - self.assertTrue(callable(test_globals["json"].dump)) + ConfigExpression(id="", config=stmt, globals=test_globals).evaluate() + self.assertTrue(callable(test_globals[mod_name].dump)) + + @parameterized.expand( + [ + ("$from json import dump", "dump"), + ("$from json import dump, dumps", "dump"), + ("$from json import dump as jd", "jd"), + ("$from json import dump as jd, dumps as ds", "jd"), + ] + ) + def test_import_from(self, stmt, mod_name): + test_globals = {} + ConfigExpression(id="", config=stmt, globals=test_globals).evaluate() + self.assertTrue(callable(test_globals[mod_name])) if __name__ == "__main__": From 3a763f9ac6f761a8b7f90c37d54e037ba70ce463 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 18 Mar 2022 12:31:55 +0000 Subject: [PATCH 4/8] update to resolve import statement Signed-off-by: Wenqi Li --- monai/bundle/config_item.py | 50 ++++++++++++++++++++---------- monai/bundle/reference_resolver.py | 18 ++++++++--- tests/testing_data/inference.json | 2 +- 3 files changed, 47 insertions(+), 23 deletions(-) diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py index 47c1980897..4062e783e2 100644 --- a/monai/bundle/config_item.py +++ b/monai/bundle/config_item.py @@ -19,7 +19,7 @@ from typing import Any, Dict, List, Mapping, Optional, Sequence, Union from monai.bundle.utils import EXPR_KEY -from monai.utils import ensure_tuple, instantiate, optional_import +from monai.utils import ensure_tuple, first, instantiate, optional_import __all__ = ["ComponentLocator", "ConfigItem", "ConfigExpression", "ConfigComponent"] @@ -317,22 +317,22 @@ def __init__(self, config: Any, id: str = "", globals: Optional[Dict] = None) -> self.globals = globals if globals is not None else {} def _parse_import_string(self, import_string: str): - # parse single import statement such as "from monai.transforms import Resize" - for n in ast.iter_child_nodes(ast.parse(import_string)): - if not isinstance(n, (ast.Import, ast.ImportFrom)): - return None - if len(n.names) < 1: - return None - if len(n.names) > 1: - warnings.warn(f"ignoring multiple import alias '{import_string}'.") - name, asname = f"{n.names[0].name}", n.names[0].asname - asname = name if asname is None else f"{asname}" - if isinstance(n, ast.ImportFrom): - self.globals[asname], _ = optional_import(f"{n.module}", name=f"{name}") - return self.globals[asname] - elif isinstance(n, ast.Import): - self.globals[asname], _ = optional_import(f"{name}") - return self.globals[asname] + """parse single import statement such as "from monai.transforms import Resize""" + node = first(ast.iter_child_nodes(ast.parse(import_string))) + if not isinstance(node, (ast.Import, ast.ImportFrom)): + return None + if len(node.names) < 1: + return None + if len(node.names) > 1: + warnings.warn(f"ignoring multiple import alias '{import_string}'.") + name, asname = f"{node.names[0].name}", node.names[0].asname + asname = name if asname is None else f"{asname}" + if isinstance(node, ast.ImportFrom): + self.globals[asname], _ = optional_import(f"{node.module}", name=f"{name}") + return self.globals[asname] + if isinstance(node, ast.Import): + self.globals[asname], _ = optional_import(f"{name}") + return self.globals[asname] return None def evaluate(self, locals: Optional[Dict] = None): @@ -365,3 +365,19 @@ def is_expression(cls, config: Union[Dict, List, str]) -> bool: """ return isinstance(config, str) and config.startswith(cls.prefix) + + @classmethod + def is_import_statement(cls, config: Union[Dict, List, str]) -> bool: + """ + Check whether the config is an import statement (a special case of expression). + + Args: + config: input config content to check. + """ + if not cls.is_expression(config): + return False + if "import" not in config: + return False + return isinstance( + first(ast.iter_child_nodes(ast.parse(f"{config[len(cls.prefix) :]}"))), (ast.Import, ast.ImportFrom) + ) diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index c1599c2124..23f1b1964d 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -99,7 +99,7 @@ def get_item(self, id: str, resolve: bool = False, **kwargs): def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **kwargs): """ - Resolve one ``ConfigItem`` of ``id``, cache the resolved result in ``resolved_content``. + Resolve and return one ``ConfigItem`` of ``id``, cache the resolved result in ``resolved_content``. If it has unresolved references, recursively resolve the referring items first. Args: @@ -111,6 +111,8 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, ** Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True. """ + if id in self.resolved_content: + return self.resolved_content[id] try: item = look_up_option(id, self.items, print_all_options=False) except ValueError as err: @@ -122,6 +124,13 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, ** waiting_list.add(id) ref_ids = self.find_refs_in_config(config=item_config, id=id) + for t, v in self.items.items(): + if ( + t not in self.resolved_content + and isinstance(v, ConfigExpression) + and v.is_import_statement(v.get_config()) + ): + self.resolved_content[t] = v.evaluate() if kwargs.get("eval_expr", True) else v for d in ref_ids: # if current item has reference already in the waiting list, that's circular references if d in waiting_list: @@ -150,10 +159,11 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, ** ) else: self.resolved_content[id] = new_config + return self.resolved_content[id] def get_resolved_content(self, id: str, **kwargs): """ - Get the resolved ``ConfigItem`` by id. If there are unresolved references, try to resolve them first. + Get the resolved ``ConfigItem`` by id. Args: id: id name of the expected item. @@ -161,9 +171,7 @@ def get_resolved_content(self, id: str, **kwargs): Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True. """ - if id not in self.resolved_content: - self._resolve_one_item(id=id, **kwargs) - return self.resolved_content[id] + return self._resolve_one_item(id=id, **kwargs) @classmethod def match_refs_pattern(cls, value: str) -> Set[str]: diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json index 68672cc420..cc9ddef866 100644 --- a/tests/testing_data/inference.json +++ b/tests/testing_data/inference.json @@ -4,7 +4,7 @@ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')", "set_seed": "$monai.utils.set_determinism(0)", "print_test_name": "$print('json_test')", - "print_glob_file": "$print(@import_glob.__file__)", + "print_glob_file": "$print(glob.__file__)", "network_def": { "_target_": "UNet", "spatial_dims": 3, From 3305071acf0994d2547c4e05ebf80f4fe559ffa4 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 18 Mar 2022 12:44:50 +0000 Subject: [PATCH 5/8] update based on comments Signed-off-by: Wenqi Li --- monai/bundle/config_item.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py index 4062e783e2..3840a54e70 100644 --- a/monai/bundle/config_item.py +++ b/monai/bundle/config_item.py @@ -170,12 +170,14 @@ class ConfigComponent(ConfigItem, Instantiable): - class or function identifier of the python module, specified by ``"_target_"``, indicating a build-in python class or function such as ``"LoadImageDict"``, or a full module name, such as ``"monai.transforms.LoadImageDict"``. - - ``"_requires_"``: specifies reference IDs (string starts with ``"@"``) or ``ConfigExpression`` + - ``"_requires_"`` (optional): specifies reference IDs (string starts with ``"@"``) or ``ConfigExpression`` of the dependencies for this ``ConfigComponent`` object. These dependencies will be evaluated/instantiated before this object is instantiated. - - ``"_disabled_"``: a flag to indicate whether to skip the instantiation. + - ``"_disabled_"`` (optional): a flag to indicate whether to skip the instantiation. - Other fields in the config content are input arguments to the python module. + Other fields in the config content are input arguments to the python module. ``"_requires_"`` is only useful when the + component doesn't explicitly depends on the other `ConfigItems` via its arguments, + but requires the dependencies to be instantiated/evaluated beforehand. .. code-block:: python From 9fc4d40f6556b495796411e9995c45522864944e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 18 Mar 2022 13:05:04 +0000 Subject: [PATCH 6/8] update Signed-off-by: Wenqi Li --- monai/bundle/reference_resolver.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/monai/bundle/reference_resolver.py b/monai/bundle/reference_resolver.py index 23f1b1964d..f9f73c9c71 100644 --- a/monai/bundle/reference_resolver.py +++ b/monai/bundle/reference_resolver.py @@ -123,7 +123,6 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, ** waiting_list = set() waiting_list.add(id) - ref_ids = self.find_refs_in_config(config=item_config, id=id) for t, v in self.items.items(): if ( t not in self.resolved_content @@ -131,7 +130,7 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, ** and v.is_import_statement(v.get_config()) ): self.resolved_content[t] = v.evaluate() if kwargs.get("eval_expr", True) else v - for d in ref_ids: + for d in self.find_refs_in_config(config=item_config, id=id): # if current item has reference already in the waiting list, that's circular references if d in waiting_list: raise ValueError(f"detected circular references '{d}' for id='{id}' in the config content.") From 040157acf6e6bf59929d0be99147b202fccd289e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Sun, 20 Mar 2022 14:48:02 +0000 Subject: [PATCH 7/8] update based on comments Signed-off-by: Wenqi Li --- monai/bundle/config_item.py | 8 ++++---- tests/test_config_item.py | 9 +++++++++ 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/monai/bundle/config_item.py b/monai/bundle/config_item.py index 3840a54e70..0531c6f14e 100644 --- a/monai/bundle/config_item.py +++ b/monai/bundle/config_item.py @@ -172,12 +172,12 @@ class ConfigComponent(ConfigItem, Instantiable): or a full module name, such as ``"monai.transforms.LoadImageDict"``. - ``"_requires_"`` (optional): specifies reference IDs (string starts with ``"@"``) or ``ConfigExpression`` of the dependencies for this ``ConfigComponent`` object. These dependencies will be - evaluated/instantiated before this object is instantiated. + evaluated/instantiated before this object is instantiated. It is useful when the + component doesn't explicitly depends on the other `ConfigItems` via its arguments, + but requires the dependencies to be instantiated/evaluated beforehand. - ``"_disabled_"`` (optional): a flag to indicate whether to skip the instantiation. - Other fields in the config content are input arguments to the python module. ``"_requires_"`` is only useful when the - component doesn't explicitly depends on the other `ConfigItems` via its arguments, - but requires the dependencies to be instantiated/evaluated beforehand. + Other fields in the config content are input arguments to the python module. .. code-block:: python diff --git a/tests/test_config_item.py b/tests/test_config_item.py index 83e214e369..fbd76e7be7 100644 --- a/tests/test_config_item.py +++ b/tests/test_config_item.py @@ -109,6 +109,15 @@ def test_import_from(self, stmt, mod_name): test_globals = {} ConfigExpression(id="", config=stmt, globals=test_globals).evaluate() self.assertTrue(callable(test_globals[mod_name])) + self.assertTrue(ConfigExpression.is_import_statement(ConfigExpression(id="", config=stmt).config)) + + @parameterized.expand( + [("$from json import dump", True), ("$print()", False), ("$import json", True), ("import json", False)] + ) + def test_is_import_stmt(self, stmt, expected): + expr = ConfigExpression(id="", config=stmt) + flag = expr.is_import_statement(expr.config) + self.assertEqual(flag, expected) if __name__ == "__main__": From b3ce6116273ac54ca4fd95c11ced4146e40447b0 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 21 Mar 2022 11:24:16 +0000 Subject: [PATCH 8/8] fixes unit tests Signed-off-by: Wenqi Li --- monai/bundle/scripts.py | 2 +- tests/test_bundle_verify_net.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/monai/bundle/scripts.py b/monai/bundle/scripts.py index 23f57df8ce..5bbde5fd62 100644 --- a/monai/bundle/scripts.py +++ b/monai/bundle/scripts.py @@ -260,7 +260,7 @@ def verify_net_in_out( if it is a list of file paths, the content of them will be merged. device: target device to run the network forward computation, if None, prefer to "cuda" if existing. p: power factor to generate fake data shape if dim of expected shape is "x**p", default to 1. - p: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1. + n: multiply factor to generate fake data shape if dim of expected shape is "x*n", default to 1. any: specified size to generate fake data shape if dim of expected shape is "*", default to 1. args_file: a JSON or YAML file to provide default values for `meta_file`, `config_file`, `net_id` and override pairs. so that the command line inputs can be simplified. diff --git a/tests/test_bundle_verify_net.py b/tests/test_bundle_verify_net.py index c6aa6d61fb..62f99aab99 100644 --- a/tests/test_bundle_verify_net.py +++ b/tests/test_bundle_verify_net.py @@ -38,7 +38,10 @@ def test_verify(self, meta_file, config_file): cmd = [sys.executable, "-m", "monai.bundle", "verify_net_in_out", "network_def", "--meta_file", meta_file] cmd += ["--config_file", config_file, "-n", "2", "--any", "32", "--args_file", def_args_file] cmd += ["--_meta_#network_data_format#inputs#image#spatial_shape", "[32,'*','4**p*n']"] - ret = subprocess.check_call(cmd) + + test_env = os.environ.copy() + print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES")) + ret = subprocess.check_call(cmd, env=test_env) self.assertEqual(ret, 0)