Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] 3482 Add detection for circular macro replacement #4029

Open
wants to merge 7 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 19 additions & 5 deletions monai/bundle/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import re
from copy import deepcopy
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Tuple, Union
from typing import Any, Dict, Optional, Sequence, Set, Tuple, Union

from monai.bundle.config_item import ComponentLocator, ConfigComponent, ConfigExpression, ConfigItem
from monai.bundle.reference_resolver import ReferenceResolver
Expand Down Expand Up @@ -253,7 +253,7 @@ def read_config(self, f: Union[PathLike, Sequence[PathLike], Dict], **kwargs):
content.update(self.load_config_files(f, **kwargs))
self.set(config=content)

def _do_resolve(self, config: Any, id: str = ""):
def _do_resolve(self, config: Any, id: str = "", waiting_list: Optional[Set[str]] = None):
"""
Recursively resolve `self.config` to replace the relative ids with absolute ids, for example,
`@##A` means `A` in the upper level. and replace the macro tokens with target content,
Expand All @@ -266,18 +266,32 @@ def _do_resolve(self, config: Any, id: str = ""):
go one level further into the nested structures.
Use digits indexing from "0" for list or other strings for dict.
For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``.
waiting_list: set of macro replacement ids pending to be resolved.
It's used to detect circular references such as:
`{"A": {"dep": "%B"}, "B": {"dep": "%A"}}`.

"""
if waiting_list is None:
waiting_list = set()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If waiting_list is actually a list we could look at the previous item in the exception to help the user trace what's happened.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's an interesting idea, I think the previous implementation in ReferenceResolver is a list and @wyli simplified it to set later. @wyli What do you think about it?

Thanks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think set is more suitable/efficient for in checks here...not sure if the previous item is very intuitive when the circular ref is formed by multiple reference a->b->...->c->a

if isinstance(config, (dict, list)):
for k, v in enumerate(config) if isinstance(config, list) else config.items():
sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else k
config[k] = self._do_resolve(v, sub_id)
config[k] = self._do_resolve(v, sub_id, waiting_list)
if isinstance(config, str):
config = self.resolve_relative_ids(id, config)
if config.startswith(MACRO_KEY):
waiting_list.add(id)
path, ids = ConfigParser.split_path_id(config[len(MACRO_KEY) :])
parser = ConfigParser(config=self.get() if not path else ConfigParser.load_config_file(path))
return self._do_resolve(config=deepcopy(parser[ids]))
if not path:
# if the target id is in the waiting list, that's circular references
if ids in waiting_list:
raise ValueError(f"detected circular references in macro replacement '{ids}' for id='{id}'.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this would be raise ValueError(f"Detected circular references in macro replacement '{ids}' for id='{id}'. (previous id='{waiting_list[-2]}'") or something like that.

parser = ConfigParser(config=self.get())
config = self._do_resolve(deepcopy(parser[ids]), ids, waiting_list)
else:
# don't support recursive macro replacement in another config file
config = ConfigParser(config=ConfigParser.load_config_file(path))[ids]
waiting_list.discard(id)
return config

def resolve_macro_and_relative_ids(self):
Expand Down
2 changes: 1 addition & 1 deletion monai/bundle/reference_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def _resolve_one_item(self, id: str, waiting_list: Optional[Set[str]] = None, **
id: id name of ``ConfigItem`` to be resolved.
waiting_list: set of ids pending to be resolved.
It's used to detect circular references such as:
`{"name": "A", "dep": "@B"}` and `{"name": "B", "dep": "@A"}`.
`{"A": {"dep": "@B"}, "B": {"dep": "@A"}}`.
kwargs: keyword arguments to pass to ``_resolve_one_item()``.
Currently support ``instantiate`` and ``eval_expr``. Both are defaulting to True.

Expand Down
18 changes: 18 additions & 0 deletions tests/test_config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest
from unittest import skipUnless

Expand Down Expand Up @@ -142,6 +144,22 @@ def test_relative_id(self, config):
if isinstance(item, dict):
self.assertEqual(str(item), str({"key": 1, "value1": 2, "value2": 2, "value3": [3, 4, 4, 105]}))

def test_macro_replace(self):
with tempfile.TemporaryDirectory() as tempdir:
another_file = os.path.join(tempdir, "another.json")
ConfigParser.export_config_file(config={"E": 4}, filepath=another_file)
# test relative id, recursive macro replacement, and macro in another file
config = {"A": {"B": 1, "C": 2}, "D": [3, "%A#B", "%#1", f"%{another_file}#E"]}
parser = ConfigParser(config=config)
parser.resolve_macro_and_relative_ids()
self.assertEqual(str(parser.get()), str({"A": {"B": 1, "C": 2}, "D": [3, 1, 1, 4]}))

def test_circular_macro_replace(self):
config = {"A": "%B", "B": {"args": [1, 2, "%A"]}}
parser = ConfigParser(config=config)
with self.assertRaises(ValueError):
parser.resolve_macro_and_relative_ids()


if __name__ == "__main__":
unittest.main()