Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Potential solution for #6595 #6598

Draft
wants to merge 7 commits into
base: dev
Choose a base branch
from
Draft
64 changes: 56 additions & 8 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from __future__ import annotations

import inspect
import itertools
import random
import warnings
Expand Down Expand Up @@ -1744,22 +1745,69 @@ def scale_affine(spatial_size, new_spatial_size, centered: bool = True):
def attach_hook(func, hook, mode="pre"):
"""
Adds `hook` before or after a `func` call. If mode is "pre", the wrapper will call hook then func.
If the mode is "post", the wrapper will call func then hook.
If the mode is "post", the wrapper will call func then hook. In the case that additional arguments are
passed with the function 'func', the hook function will be called with any of those additional arguments
that the hook also supports. Such additional arguments must have the same name on both functions to be
matched. Unmatched arguments on 'func' are ignored when calling 'hook'.
"""
supported = {"pre", "post"}
if look_up_option(mode, supported) == "pre":
_hook, _func = hook, func
else:
_hook, _func = func, hook
_mode = look_up_option(mode, supported)

def key_in_args(args, k):
return any(k == a for a in args.args)

def index_of_key(args, k):
return args.args.index(k)

def param_has_default(args, k):
if args.defaults is None:
return False
return index_of_key(args, k) >= len(args.args) - len(args.defaults)

def param_default(args, k):
if args.defaults is None:
raise ValueError(f"Parameter {k} has no default")
d_k = len(args.args) - index_of_key(args, k) - 1
if d_k >= len(args.defaults):
raise ValueError(f"Parameter {k} has no default")
return args.defaults[d_k]

def key_at_index(args, i):
return args.args[i]

f_args = inspect.getfullargspec(func)
h_args = inspect.getfullargspec(hook)

@wraps(func)
def wrapper(inst, data):
data = _hook(inst, data)
return _func(inst, data)
def wrapper(inst, data, *args, **kwargs):
h_kwargs = dict()

# iterate over the positional args that the wrapper was called with, getting their names.
# add any values for parameter names that are also in the hook function's names
for i_a, a in enumerate(args[2:]):
k = key_at_index(f_args, i_a)
if key_in_args(h_args, k):
h_kwargs[k] = a

# go over parameters in the keyword args, adding any values for parameter names that are also in the hook function's names
for k, v in kwargs.items():
if key_in_args(h_args, k):
h_kwargs[k] = v

# handle the corner case where there is a parameter without a default on _hook that has a default on _func, but that hasn't
# been set by the caller. In this case, we get the default for that parameter on _func and pass it to _hook
for k in h_args.args:
if param_has_default(h_args, k) is False and k not in h_kwargs and param_has_default(f_args, k) is True:
h_kwargs[k] = param_default(f_args, k)

if _mode == "pre":
return func(inst, hook(inst, data, **h_kwargs), *args, **kwargs)
return hook(inst, func(inst, data, *args, **kwargs), **h_kwargs)

return wrapper



def sync_meta_info(key, data_dict, t: bool = True):
"""
Given the key, sync up between metatensor `data_dict[key]` and meta_dict `data_dict[key_transforms/meta_dict]`.
Expand Down