Skip to content

Commit

Permalink
Make Auto3DSeg algo able to find algorithm templates (#6436)
Browse files Browse the repository at this point in the history
Fixes #6435 .

### Description

In this PR, the changes are
- Make `template_path` a property that's available for all Algo classes
to find the path to instantiate the class. The default value of
`template_path` is None.
- Enhance `algo_from_pickle` function by making it search in a few
candidate directories to instantiate the `Algo` class if needed.
- Remove checking if an algo instance is `BundleAlgo` in HPO logics
because `template_path` is universal for all `Algo` now.
- Update out-of-date docstring for AutoRunner

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [x] New tests added to cover the changes.
- [x] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Mingxin Zheng <[email protected]>
  • Loading branch information
mingxin-zheng authored Apr 27, 2023
1 parent 4f29172 commit 5f344cc
Show file tree
Hide file tree
Showing 9 changed files with 223 additions and 86 deletions.
1 change: 1 addition & 0 deletions .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ jobs:
run: |
pwd && git log -1 && which python
./runtests.sh -b
python -m tests.test_auto3dseg_bundlegen
python -m tests.test_auto3dseg_ensemble
python -m tests.test_auto3dseg_hpo
python -m tests.test_integration_autorunner
Expand Down
31 changes: 9 additions & 22 deletions monai/apps/auto3dseg/auto_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class AutoRunner:
.. code-block:: python
work_dir = "./work_dir"
input = "path_to_yaml_data_cfg"
input = "path/to/input_yaml"
runner = AutoRunner(work_dir=work_dir, input=input)
runner.run()
Expand All @@ -116,7 +116,7 @@ class AutoRunner:
.. code-block:: python
work_dir = "./work_dir"
input = "path_to_yaml_data_cfg"
input = "path/to/input_yaml"
algos = ["segresnet", "dints"]
runner = AutoRunner(work_dir=work_dir, input=input, algos=algos)
runner.run()
Expand All @@ -126,7 +126,7 @@ class AutoRunner:
.. code-block:: python
work_dir = "./work_dir"
input = "path_to_yaml_data_cfg"
input = "path/to/input_yaml"
algos = "segresnet"
templates_path_or_url = "./local_path_to/algorithm_templates"
runner = AutoRunner(work_dir=work_dir, input=input, algos=algos, templates_path_or_url=templates_path_or_url)
Expand All @@ -136,12 +136,10 @@ class AutoRunner:
.. code-block:: python
input = "path_to_yaml_data_cfg"
input = "path/to/input_yaml"
runner = AutoRunner(input=input)
train_param = {
"CUDA_VISIBLE_DEVICES": [0],
"num_iterations": 8,
"num_iterations_per_validation": 4,
"num_epochs_per_validation": 1,
"num_images_per_batch": 2,
"num_epochs": 2,
}
Expand All @@ -152,7 +150,7 @@ class AutoRunner:
.. code-block:: python
input = "path_to_yaml_data_cfg"
input = "path/to/input_yaml"
runner = AutoRunner(input=input)
runner.set_num_fold(n_fold = 2)
runner.run()
Expand All @@ -161,7 +159,7 @@ class AutoRunner:
.. code-block:: python
input = "path_to_yaml_data_cfg"
input = "path/to/input_yaml"
pred_params = {
'files_slices': slice(0,2),
'mode': "vote",
Expand All @@ -175,14 +173,7 @@ class AutoRunner:
.. code-block:: python
input = "path_to_yaml_data_cfg"
pred_param = {
"CUDA_VISIBLE_DEVICES": [0],
"num_iterations": 8,
"num_iterations_per_validation": 4,
"num_images_per_batch": 2,
"num_epochs": 2,
}
input = "path/to/input_yaml"
runner = AutoRunner(input=input, hpo=True)
runner.set_nni_search_space({"learning_rate": {"_type": "choice", "_value": [0.0001, 0.001, 0.01, 0.1]}})
runner.run()
Expand Down Expand Up @@ -471,7 +462,7 @@ def set_training_params(self, params: dict[str, Any] | None = None) -> None:
Examples:
For BundleAlgo objects, the training parameter to shorten the training time to a few epochs can be
{"num_iterations": 8, "num_iterations_per_validation": 4}
{"num_epochs": 2, "num_epochs_per_validation": 1}
"""
self.train_params = deepcopy(params) if params is not None else {}
Expand Down Expand Up @@ -596,10 +587,6 @@ def set_analyze_params(self, params: dict[str, Any] | None = None) -> None:
params: a dict that defines the overriding key-value pairs during training. The overriding method
is defined by the algo class.
Examples:
For BundleAlgo objects, the training parameter to shorten the training time to a few epochs can be
{"num_iterations": 8, "num_iterations_per_validation": 4}
"""
if params is None:
self.analyze_params = {"do_ccp": False, "device": "cuda"}
Expand Down
7 changes: 4 additions & 3 deletions monai/apps/auto3dseg/bundle_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from monai.auto3dseg.algo_gen import Algo, AlgoGen
from monai.auto3dseg.utils import algo_to_pickle
from monai.bundle.config_parser import ConfigParser
from monai.config import PathLike
from monai.utils import ensure_tuple
from monai.utils.enums import AlgoKeys

Expand Down Expand Up @@ -63,7 +64,7 @@ class BundleAlgo(Algo):
"""

def __init__(self, template_path: str):
def __init__(self, template_path: PathLike):
"""
Create an Algo instance based on the predefined Algo template.
Expand Down Expand Up @@ -153,9 +154,9 @@ def export_to_disk(self, output_path: str, algo_name: str, **kwargs: Any) -> Non
os.makedirs(self.output_path, exist_ok=True)
if os.path.isdir(self.output_path):
shutil.rmtree(self.output_path)
shutil.copytree(self.template_path, self.output_path)
shutil.copytree(str(self.template_path), self.output_path)
else:
self.output_path = self.template_path
self.output_path = str(self.template_path)
if kwargs.pop("fill_template", True):
self.fill_records = self.fill_template_config(self.data_stats_files, self.output_path, **kwargs)
logger.info(f"Generated:{self.output_path}")
Expand Down
28 changes: 4 additions & 24 deletions monai/apps/auto3dseg/hpo_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,7 @@ def __init__(self, algo: Algo | None = None, params: dict | None = None):
else:
self.algo = algo

if isinstance(self.algo, BundleAlgo):
self.obj_filename = algo_to_pickle(self.algo, template_path=self.algo.template_path)
else:
self.obj_filename = algo_to_pickle(self.algo)
# nni instruction unknown
self.obj_filename = algo_to_pickle(self.algo, template_path=self.algo.template_path)

def get_obj_filename(self):
"""Return the filename of the dumped pickle algo object."""
Expand Down Expand Up @@ -226,9 +222,6 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P

self.algo, algo_meta_data = algo_from_pickle(obj_filename, template_path=template_path)

if isinstance(self.algo, BundleAlgo): # algo's template path needs override
self.algo.template_path = algo_meta_data["template_path"]

# step 1 sample hyperparams
params = self.get_hyperparameters()
# step 2 set the update params for the algo to run in the next trial
Expand All @@ -240,10 +233,7 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P
acc = self.algo.get_score()
algo_meta_data = {str(AlgoKeys.SCORE): acc}

if isinstance(self.algo, BundleAlgo):
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
else:
algo_to_pickle(self.algo, **algo_meta_data)
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
self.set_score(acc)


Expand Down Expand Up @@ -304,11 +294,7 @@ def __init__(self, algo: Algo | None = None, params: dict | None = None) -> None
else:
self.algo = algo

if isinstance(self.algo, BundleAlgo):
self.obj_filename = algo_to_pickle(self.algo, template_path=self.algo.template_path)
else:
self.obj_filename = algo_to_pickle(self.algo)
# nni instruction unknown
self.obj_filename = algo_to_pickle(self.algo, template_path=self.algo.template_path)

def get_obj_filename(self):
"""Return the dumped pickle object of algo."""
Expand Down Expand Up @@ -399,9 +385,6 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P

self.algo, algo_meta_data = algo_from_pickle(obj_filename, template_path=template_path)

if isinstance(self.algo, BundleAlgo): # algo's template path needs override
self.algo.template_path = algo_meta_data["template_path"]

# step 1 sample hyperparams
params = self.get_hyperparameters()
# step 2 set the update params for the algo to run in the next trial
Expand All @@ -412,8 +395,5 @@ def run_algo(self, obj_filename: str, output_folder: str = ".", template_path: P
# step 4 report validation acc to controller
acc = self.algo.get_score()
algo_meta_data = {str(AlgoKeys.SCORE): acc}
if isinstance(self.algo, BundleAlgo):
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
else:
algo_to_pickle(self.algo, **algo_meta_data)
algo_to_pickle(self.algo, template_path=self.algo.template_path, **algo_meta_data)
self.set_score(acc)
3 changes: 0 additions & 3 deletions monai/apps/auto3dseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ def import_bundle_algo_history(

algo, algo_meta_data = algo_from_pickle(obj_filename, template_path=template_path)

if isinstance(algo, BundleAlgo): # algo's template path needs override
algo.template_path = algo_meta_data["template_path"]

best_metric = algo_meta_data.get(AlgoKeys.SCORE, None)
if best_metric is None:
try:
Expand Down
3 changes: 3 additions & 0 deletions monai/auto3dseg/algo_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

from monai.config import PathLike
from monai.transforms import Randomizable


Expand All @@ -20,6 +21,8 @@ class Algo:
such as image preprocessing, followed by deep learning model training and evaluation.
"""

template_path: PathLike | None = None

def set_data_stats(self, *args, **kwargs):
"""Provide dataset (and summaries) so that the model creation can depend on the input datasets."""
pass
Expand Down
89 changes: 55 additions & 34 deletions monai/auto3dseg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@

from __future__ import annotations

import logging
import os
import pickle
import sys
import warnings
from copy import deepcopy
from numbers import Number
from typing import Any, cast
Expand All @@ -25,6 +25,7 @@
from monai.auto3dseg import Algo
from monai.bundle.config_parser import ConfigParser
from monai.bundle.utils import ID_SEP_KEY
from monai.config import PathLike
from monai.data.meta_tensor import MetaTensor
from monai.transforms import CropForeground, ToCupy
from monai.utils import min_version, optional_import
Expand Down Expand Up @@ -272,20 +273,20 @@ def verify_report_format(report: dict, report_format: dict) -> bool:
return True


def algo_to_pickle(algo: Algo, **algo_meta_data: Any) -> str:
def algo_to_pickle(algo: Algo, template_path: PathLike | None = None, **algo_meta_data: Any) -> str:
"""
Export the Algo object to pickle file
Export the Algo object to pickle file.
Args:
algo: Algo-like object
algo_meta_data: additional keyword to save into the dictionary. It may include template_path
which is used to instantiate the class. It may also include model training info
algo: Algo-like object.
template_path: a str path that is needed to be added to the sys.path to instantiate the class.
algo_meta_data: additional keyword to save into the dictionary, for example, model training info
such as acc/best_metrics
Returns:
filename of the pickled Algo object
"""
data = {"algo_bytes": pickle.dumps(algo)}
data = {"algo_bytes": pickle.dumps(algo), "template_path": str(template_path)}
pkl_filename = os.path.join(algo.get_output_path(), "algo_object.pkl")
for k, v in algo_meta_data.items():
data.update({k: v})
Expand All @@ -295,22 +296,24 @@ def algo_to_pickle(algo: Algo, **algo_meta_data: Any) -> str:
return pkl_filename


def algo_from_pickle(pkl_filename: str, **kwargs: Any) -> Any:
def algo_from_pickle(pkl_filename: str, template_path: PathLike | None = None, **kwargs: Any) -> Any:
"""
Import the Algo object from a pickle file
Import the Algo object from a pickle file.
Args:
pkl_filename: name of the pickle file
algo_templates_dir: the algorithm script folder which is needed to instantiate the object.
If it is None, the function will use the internal ``'algo_templates_dir`` in the object
dict.
pkl_filename: the name of the pickle file.
template_path: a folder containing files to instantiate the Algo. Besides the `template_path`,
this function will also attempt to use the `template_path` saved in the pickle file and a directory
named `algorithm_templates` in the parent folder of the folder containing the pickle file.
Returns:
algo: Algo-like object
algo: the Algo object saved in the pickle file.
algo_meta_data: additional keyword saved in the pickle file, for example, acc/best_metrics.
Raises:
ValueError if the pkl_filename does not contain a dict, or the dict does not contain
``template_path`` or ``algo_bytes``
ValueError if the pkl_filename does not contain a dict, or the dict does not contain `algo_bytes`.
ModuleNotFoundError if it is unable to instiante the Algo class.
"""
with open(pkl_filename, "rb") as f_pi:
data_bytes = f_pi.read()
Expand All @@ -323,30 +326,48 @@ def algo_from_pickle(pkl_filename: str, **kwargs: Any) -> Any:
raise ValueError(f"key [algo_bytes] not found in {data}. Unable to instantiate.")

algo_bytes = data.pop("algo_bytes")
algo_meta_data = {}
algo_template_path = data.pop("template_path", None)

if "template_path" in kwargs: # add template_path to sys.path
template_path = kwargs["template_path"]
if template_path is None: # then load template_path from pickled data
if "template_path" not in data:
raise ValueError(f"key [template_path] not found in {data}")
template_path = data.pop("template_path")
template_paths_candidates: list[str] = []

if not os.path.isdir(template_path):
raise ValueError(f"Algorithm templates {template_path} is not a directory")
# Example of template path: "algorithm_templates/dints".
sys.path.insert(0, os.path.abspath(os.path.join(template_path, "..")))
algo_meta_data.update({"template_path": template_path})
if os.path.isdir(str(template_path)):
template_paths_candidates.append(os.path.abspath(str(template_path)))
template_paths_candidates.append(os.path.abspath(os.path.join(str(template_path), "..")))

if os.path.isdir(str(algo_template_path)):
template_paths_candidates.append(os.path.abspath(algo_template_path))
template_paths_candidates.append(os.path.abspath(os.path.join(algo_template_path, "..")))

algo = pickle.loads(algo_bytes)
pkl_dir = os.path.dirname(pkl_filename)
if pkl_dir != algo.get_output_path():
warnings.warn(
f"{algo.get_output_path()} does not contain {pkl_filename}."
f"Now override the Algo output_path with: {pkl_dir}"
)
algo_template_path_fuzzy = os.path.join(pkl_dir, "..", "algorithm_templates")

if os.path.isdir(algo_template_path_fuzzy):
template_paths_candidates.append(os.path.abspath(algo_template_path_fuzzy))

if len(template_paths_candidates) == 0:
# no template_path provided or needed
algo = pickle.loads(algo_bytes)
algo.template_path = None
else:
for i, p in enumerate(template_paths_candidates):
try:
sys.path.append(p)
algo = pickle.loads(algo_bytes)
break
except ModuleNotFoundError as not_found_err:
logging.debug(f"Folder {p} doesn't contain the Algo templates for Algo instantiation.")
sys.path.pop()
if i == len(template_paths_candidates) - 1:
raise ValueError(
f"Failed to instantiate {pkl_filename} with {template_paths_candidates}"
) from not_found_err
algo.template_path = p

if os.path.abspath(pkl_dir) != os.path.abspath(algo.get_output_path()):
logging.debug(f"{algo.get_output_path()} is changed. Now override the Algo output_path with: {pkl_dir}.")
algo.output_path = pkl_dir

algo_meta_data = {}
for k, v in data.items():
algo_meta_data.update({k: v})

Expand Down
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def run_testsuit():
exclude_cases = [ # these cases use external dependencies
"test_ahnet",
"test_arraydataset",
"test_auto3dseg_bundlegen",
"test_auto3dseg_ensemble",
"test_auto3dseg_hpo",
"test_auto3dseg",
Expand Down
Loading

0 comments on commit 5f344cc

Please sign in to comment.