WIP: Move to numpy.Generators #6899

Draft: wants to merge 7 commits into base: dev
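For context on the changes below: the PR swaps MONAI's legacy numpy.random.RandomState plumbing for the newer numpy.random.Generator API, renaming set_random_state to set_random_generator and randint to integers at the call sites. A minimal side-by-side sketch of the two NumPy APIs (plain NumPy only, no MONAI internals assumed):

import numpy as np

# legacy API: RandomState with randint
rs = np.random.RandomState(123)
old_draw = rs.randint(0, 10, size=3)      # upper bound exclusive

# new API: Generator with integers, usually built via default_rng
rng = np.random.default_rng(123)
new_draw = rng.integers(0, 10, size=3)    # upper bound exclusive by default

# the two streams are not bit-identical for the same seed, which is why the PR
# also touches the seeding helpers rather than only renaming the calls
print(old_draw, new_draw)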
6 changes: 3 additions & 3 deletions monai/apps/datasets.py
@@ -110,7 +110,7 @@ def __init__(
self.section = section
self.val_frac = val_frac
self.test_frac = test_frac
self.set_random_state(seed=seed)
self.set_random_generator(seed=seed)
tarfile_name = root_dir / self.compressed_file_name
dataset_dir = root_dir / self.dataset_folder_name
self.num_class = 0
@@ -306,7 +306,7 @@ def __init__(
raise ValueError("Root directory root_dir must be a directory.")
self.section = section
self.val_frac = val_frac
self.set_random_state(seed=seed)
self.set_random_generator(seed=seed)
if task not in self.resource:
raise ValueError(f"Unsupported task: {task}, available options are: {list(self.resource.keys())}.")
dataset_dir = root_dir / task
@@ -530,7 +530,7 @@ def __init__(
self.ref_series_uid_tag = ref_series_uid_tag
self.ref_sop_uid_tag = ref_sop_uid_tag

self.set_random_state(seed=seed)
self.set_random_generator(seed=seed)
download_dir = os.path.join(root_dir, collection)
load_tags = list(specific_tags)
load_tags += [modality_tag]
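The dataset constructors above only change which seeding method they call; the method itself would live on the Randomizable base class and is not part of this hunk. A rough sketch of what set_random_generator could look like, assuming it mirrors the existing set_random_state, so treat the details as an illustration rather than the actual MONAI implementation:

import numpy as np

MAX_SEED = np.iinfo(np.uint32).max + 1  # assumed to match MONAI's existing constant

class RandomizableSketch:
    R: np.random.Generator = np.random.default_rng()

    def set_random_generator(self, seed=None, generator=None):
        # an explicit seed wins over a passed-in generator, as seed did over state before
        if seed is not None:
            self.R = np.random.default_rng(int(seed) % MAX_SEED)
        elif generator is not None:
            self.R = generator
        else:
            self.R = np.random.default_rng()
        return self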
21 changes: 13 additions & 8 deletions monai/apps/detection/transforms/dictionary.py
@@ -50,6 +50,7 @@
from monai.utils import InterpolateMode, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
from monai.utils.enums import PostFix, TraceKeys
from monai.utils.type_conversion import convert_data_type, convert_to_tensor
from monai.utils.utils_random_generator_adaptor import SupportsRandomGeneration

__all__ = [
"StandardizeEmptyBoxd",
@@ -566,9 +567,11 @@ def __init__(
self.align_corners = ensure_tuple_rep(align_corners, len(self.image_keys))
self.keep_size = keep_size

def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandZoomBoxd:
super().set_random_state(seed, state)
self.rand_zoom.set_random_state(seed, state)
def set_random_generator(
self, seed: int | None = None, generator: SupportsRandomGeneration | None = None
) -> Randomizable:
super().set_random_generator(seed, generator=generator)
self.rand_zoom.set_random_generator(seed, generator=generator)
return self

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
@@ -735,8 +738,10 @@ def __init__(
self.flipper = Flip(spatial_axis=spatial_axis)
self.box_flipper = FlipBox(spatial_axis=spatial_axis)

def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> RandFlipBoxd:
super().set_random_state(seed, state)
def set_random_generator(
self, seed: int | None = None, generator: SupportsRandomGeneration | None = None
) -> RandFlipBoxd:
super().set_random_generator(seed, generator)
return self

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
@@ -1177,8 +1182,8 @@ def randomize( # type: ignore
image_size,
fg_indices_,
bg_indices_,
self.R,
self.allow_smaller,
allow_smaller=self.allow_smaller,
generator=self.R,
)

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> list[dict[Hashable, torch.Tensor]]:
@@ -1371,7 +1376,7 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, torch.Tensor]:
return d

def randomize(self, data: Any | None = None) -> None:
self._rand_k = self.R.randint(self.max_k) + 1
self._rand_k = self.R.integers(self.max_k) + 1
super().randomize(None)

def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
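The RandZoomBoxd and RandFlipBoxd hunks above show the pattern used throughout the dictionary-level transforms: the wrapper re-seeds itself via super(), then forwards the same seed or generator to the array transform it wraps so both draw from a consistent stream. Condensed into a generic sketch (class and attribute names are placeholders, not the actual MONAI classes):

import numpy as np

class RandWrapperdSketch:
    """Placeholder for a dictionary-space wrapper around an array-space random transform."""

    def __init__(self, inner_transform):
        self.inner = inner_transform  # e.g. an array-space random zoom
        self.R = np.random.default_rng()

    def set_random_generator(self, seed=None, generator=None):
        # re-seed the wrapper itself, then forward so the wrapped transform stays in sync
        self.R = np.random.default_rng(seed) if generator is None else generator
        self.inner.set_random_generator(seed, generator=generator)
        return self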
6 changes: 3 additions & 3 deletions monai/apps/nuclick/transforms.py
@@ -345,7 +345,7 @@ def _seed_point(self, label):
indices = np.argwhere(convert_to_numpy(label) > 0)

if len(indices) > 0:
index = self.R.randint(0, len(indices))
index = self.R.integers(0, len(indices))
return indices[index, 0], indices[index, 1]
return None

@@ -382,8 +382,8 @@ def exclusion_map(self, others, dtype, jitter_range, drop_rate):
x = int(math.floor(x))
y = int(math.floor(y))
if jitter_range:
x = x + self.R.randint(low=-jitter_range, high=jitter_range)
y = y + self.R.randint(low=-jitter_range, high=jitter_range)
x = x + self.R.integers(low=-jitter_range, high=jitter_range)
y = y + self.R.integers(low=-jitter_range, high=jitter_range)
x = min(max(0, x), max_x)
y = min(max(0, y), max_y)
point_mask[x, y] = 1
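For call sites like the jitter logic above, Generator.integers is essentially a drop-in for RandomState.randint: both take (low, high) with an exclusive upper bound and accept low/high as keywords. A tiny check of that assumption:

import numpy as np

rng = np.random.default_rng(0)
jitter_range = 3
dx = rng.integers(low=-jitter_range, high=jitter_range)  # value in [-3, 2], high exclusive
dy = rng.integers(low=-jitter_range, high=jitter_range)
print(dx, dy)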
4 changes: 2 additions & 2 deletions monai/apps/reconstruction/transforms/array.py
@@ -98,7 +98,7 @@ def randomize_choose_acceleration(self) -> Sequence[float]:
lines to exclude from under-sampling
(2) acceleration: chosen acceleration factor
"""
choice = self.R.randint(0, len(self.accelerations))
choice = self.R.integers(0, len(self.accelerations))
center_fraction = self.center_fractions[choice]
acceleration = self.accelerations[choice]
return center_fraction, acceleration
@@ -257,7 +257,7 @@ def __call__(self, kspace: NdarrayOrTensor) -> Sequence[Tensor]:
# Determine acceleration rate by adjusting for the
# number of low frequencies
adjusted_accel = (acceleration * (num_low_freqs - num_cols)) / (num_low_freqs * acceleration - num_cols)
offset = self.R.randint(0, round(adjusted_accel))
offset = self.R.integers(0, round(adjusted_accel))

accel_samples = np.arange(offset, num_cols - 1, adjusted_accel)
accel_samples = np.around(accel_samples).astype(np.uint)
17 changes: 9 additions & 8 deletions monai/apps/reconstruction/transforms/dictionary.py
@@ -26,6 +26,7 @@
from monai.transforms.transform import MapTransform, RandomizableTransform
from monai.utils import FastMRIKeys
from monai.utils.type_conversion import convert_to_tensor
from monai.utils.utils_random_generator_adaptor import SupportsRandomGeneration


class ExtractDataKeyFromMetaKeyd(MapTransform):
@@ -114,11 +115,11 @@ def __init__(
is_complex=is_complex,
)

def set_random_state(
self, seed: int | None = None, state: np.random.RandomState | None = None
def set_random_generator(
self, seed: int | None = None, generator: SupportsRandomGeneration | None = None
) -> RandomKspaceMaskd:
super().set_random_state(seed, state)
self.masker.set_random_state(seed, state)
super().set_random_generator(seed, generator)
self.masker.set_random_generator(seed, generator)
return self

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, Tensor]:
@@ -182,11 +183,11 @@ def __init__(
is_complex=is_complex,
)

def set_random_state(
self, seed: int | None = None, state: np.random.RandomState | None = None
def set_random_generator(
self, seed: int | None = None, generator: SupportsRandomGeneration | None = None
) -> EquispacedKspaceMaskd:
super().set_random_state(seed, state)
self.masker.set_random_state(seed, state)
super().set_random_generator(seed, generator)
self.masker.set_random_generator(seed, generator)
return self


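Several signatures in this PR type the new argument as SupportsRandomGeneration, imported from monai.utils.utils_random_generator_adaptor. Its definition is not shown in this diff; a plausible guess, based only on how it is used here, is a Protocol covering the Generator-style sampling methods, so the following is an assumption rather than the real class:

from __future__ import annotations

from typing import Any, Protocol, runtime_checkable

@runtime_checkable
class SupportsRandomGenerationSketch(Protocol):
    # assumed surface: the calls this PR makes on self.R / the generator argument
    def integers(self, low: int, high: int | None = None, **kwargs: Any) -> Any: ...
    def random(self, *args: Any, **kwargs: Any) -> Any: ...
    def choice(self, *args: Any, **kwargs: Any) -> Any: ...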
7 changes: 6 additions & 1 deletion monai/config/type_definitions.py
@@ -12,7 +12,7 @@
from __future__ import annotations

import os
from typing import Collection, Hashable, Iterable, Sequence, TypeVar, Union
from typing import Collection, Hashable, Iterable, Sequence, SupportsIndex, Tuple, TypeVar, Union

import numpy as np
import torch
@@ -83,3 +83,8 @@
#: SequenceStr
# string or a sequence of strings for `mode` types.
SequenceStr = Union[Sequence[str], str]

Shape = Tuple[int, ...]

# Anything that can be coerced to a shape tuple
ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]]
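The two new aliases follow NumPy's own shape typing: Shape is a concrete tuple of ints, while ShapeLike accepts anything coercible to one, either a single index-like value or a sequence of them. A small illustration (the as_shape helper is hypothetical, not part of the PR):

from collections.abc import Sequence as AbcSequence
from operator import index
from typing import Sequence, SupportsIndex, Tuple, Union

Shape = Tuple[int, ...]
ShapeLike = Union[SupportsIndex, Sequence[SupportsIndex]]

def as_shape(shape: ShapeLike) -> Shape:
    # coerce a ShapeLike into a canonical Shape tuple
    if isinstance(shape, AbcSequence):
        return tuple(index(s) for s in shape)
    return (index(shape),)

assert as_shape(3) == (3,)
assert as_shape((2, 4)) == (2, 4)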
2 changes: 1 addition & 1 deletion monai/data/dataloader.py
@@ -54,7 +54,7 @@ class DataLoader(_TorchDataLoader):

class RandomDataset(torch.utils.data.Dataset, Randomizable):
def __getitem__(self, index):
return self.R.randint(0, 1000, (1,))
return self.R.integers(0, 1000, (1,))

def __len__(self):
return 16
10 changes: 5 additions & 5 deletions monai/data/dataset.py
@@ -1020,7 +1020,7 @@ def __init__(
runtime_cache=False,
) -> None:
if shuffle:
self.set_random_state(seed=seed)
self.set_random_generator(seed=seed)
self.shuffle = shuffle

self._start_pos: int = 0
@@ -1354,7 +1354,7 @@ def __init__(

"""
items = [(img, img_transform), (seg, seg_transform), (labels, label_transform)]
self.set_random_state(seed=get_seed())
self.set_random_generator(seed=get_seed())
datasets = [Dataset(x[0], x[1]) for x in items if x[0] is not None]
self.dataset = datasets[0] if len(datasets) == 1 else ZipDataset(datasets)

@@ -1364,7 +1364,7 @@ def __len__(self) -> int:
return len(self.dataset)

def randomize(self, data: Any | None = None) -> None:
self._seed = self.R.randint(MAX_SEED, dtype="uint32")
self._seed = self.R.integers(MAX_SEED, dtype="uint32")

def __getitem__(self, index: int):
self.randomize()
@@ -1373,10 +1373,10 @@ def __getitem__(self, index: int):
for dataset in self.dataset.data:
transform = getattr(dataset, "transform", None)
if isinstance(transform, Randomizable):
transform.set_random_state(seed=self._seed)
transform.set_random_generator(seed=self._seed)
transform = getattr(self.dataset, "transform", None)
if isinstance(transform, Randomizable):
transform.set_random_state(seed=self._seed)
transform.set_random_generator(seed=self._seed)
return self.dataset[index]


8 changes: 4 additions & 4 deletions monai/data/image_dataset.py
@@ -90,14 +90,14 @@ def __init__(
self.image_only = image_only
self.transform_with_metadata = transform_with_metadata
self.loader = LoadImage(reader, image_only, dtype, *args, **kwargs)
self.set_random_state(seed=get_seed())
self.set_random_generator(seed=get_seed())
self._seed = 0 # transform synchronization seed

def __len__(self) -> int:
return len(self.image_files)

def randomize(self, data: Any | None = None) -> None:
self._seed = self.R.randint(MAX_SEED, dtype="uint32")
self._seed = self.R.integers(MAX_SEED, dtype="uint32")

def __getitem__(self, index: int):
self.randomize()
@@ -116,7 +116,7 @@ def __getitem__(self, index: int):
# apply the transforms
if self.transform is not None:
if isinstance(self.transform, Randomizable):
self.transform.set_random_state(seed=self._seed)
self.transform.set_random_generator(seed=self._seed)

if self.transform_with_metadata:
img, meta_data = apply_transform(self.transform, (img, meta_data), map_items=False, unpack_items=True)
@@ -125,7 +125,7 @@ def __getitem__(self, index: int):

if self.seg_files is not None and self.seg_transform is not None:
if isinstance(self.seg_transform, Randomizable):
self.seg_transform.set_random_state(seed=self._seed)
self.seg_transform.set_random_generator(seed=self._seed)

if self.transform_with_metadata:
seg, seg_meta_data = apply_transform(
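The re-seeding above is what keeps paired image and segmentation augmentations aligned: for every item a fresh uint32 seed is drawn from the dataset's own generator and pushed into both transforms before they run. A simplified sketch of that mechanism (not the actual ImageDataset code; the transforms are assumed to expose set_random_generator):

import numpy as np

MAX_SEED = np.iinfo(np.uint32).max + 1  # assumed MONAI constant

class PairedDatasetSketch:
    def __init__(self, img_transform, seg_transform, seed=0):
        self.R = np.random.default_rng(seed)
        self.img_transform = img_transform
        self.seg_transform = seg_transform

    def __getitem__(self, index):
        item_seed = int(self.R.integers(MAX_SEED, dtype=np.uint32))
        # the same seed drives both transforms, so random flips/crops stay paired
        self.img_transform.set_random_generator(seed=item_seed)
        self.seg_transform.set_random_generator(seed=item_seed)
        return index  # real code would now load and transform the image/label pair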
4 changes: 2 additions & 2 deletions monai/data/iterable_dataset.py
@@ -127,12 +127,12 @@ def __iter__(self):
Multiple dataloader workers sharing this dataset will generate identical item sequences.
"""
self.seed += 1
super().set_random_state(seed=self.seed) # make all workers in sync
super().set_random_generator(seed=self.seed) # make all workers in sync
for _ in range(self.epochs) if self.epochs >= 0 else iter(int, 1):
yield from IterableDataset(self.generate_item(), transform=self.transform)

def randomize(self, size: int) -> None:
self._idx = self.R.randint(size)
self._idx = self.R.integers(size)


class CSVIterableDataset(IterableDataset):
2 changes: 1 addition & 1 deletion monai/data/test_time_augmentation.py
@@ -98,7 +98,7 @@ class TestTimeAugmentation:

model = UNet(...).to(device)
transform = Compose([RandAffined(keys, ...), ...])
transform.set_random_state(seed=123) # ensure deterministic evaluation
transform.set_random_generator(seed=123) # ensure deterministic evaluation

tt_aug = TestTimeAugmentation(
transform, batch_size=5, num_workers=0, inferrer_fn=model, device=device
21 changes: 16 additions & 5 deletions monai/data/utils.py
@@ -51,6 +51,8 @@
look_up_option,
optional_import,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.utils_random_generator_adaptor import SupportsRandomGeneration, _handle_legacy_random_state

pd, _ = optional_import("pandas")
DataFrame, _ = optional_import("pandas", name="DataFrame")
@@ -104,8 +106,14 @@
AFFINE_TOL = 1e-3


@deprecated_arg(
"rand_state", since="1.3.0", removed="1.5.0", new_name="generator", msg_suffix="Please use `generator` instead."
)
def get_random_patch(
dims: Sequence[int], patch_size: Sequence[int], rand_state: np.random.RandomState | None = None
dims: Sequence[int],
patch_size: Sequence[int],
rand_state: np.random.RandomState | None = None,
generator: SupportsRandomGeneration | None = None,
) -> tuple[slice, ...]:
"""
Returns a tuple of slices to define a random patch in an array of shape `dims` with size `patch_size` or the as
@@ -121,9 +129,12 @@ def get_random_patch(
(tuple of slice): a tuple of slice objects defining the patch
"""

generator = _handle_legacy_random_state(
rand_state=rand_state, generator=generator, return_legacy_default_random=True
)

# choose the minimal corner of the patch
rand_int = np.random.randint if rand_state is None else rand_state.randint
min_corner = tuple(rand_int(0, ms - ps + 1) if ms > ps else 0 for ms, ps in zip(dims, patch_size))
min_corner = tuple(generator.integers(0, ms - ps + 1) if ms > ps else 0 for ms, ps in zip(dims, patch_size))

# create the slices for each dimension which define the patch in the source array
return tuple(slice(mc, mc + ps) for mc, ps in zip(min_corner, patch_size))
@@ -703,8 +714,8 @@ def set_rnd(obj, seed: int) -> int:
return seed if _seed == seed else seed + 1 # return a different seed if there are randomizable items
if not hasattr(obj, "__dict__"):
return seed # no attribute
if hasattr(obj, "set_random_state"):
obj.set_random_state(seed=seed % MAX_SEED)
if hasattr(obj, "set_random_generator"):
obj.set_random_generator(seed=seed % MAX_SEED)
return seed + 1 # a different seed for the next component
for key in obj.__dict__:
if key.startswith("__"): # skip the private methods
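Per the updated signature above, get_random_patch keeps the deprecated rand_state argument for one release cycle and routes both it and the new generator argument through _handle_legacy_random_state. The core sampling logic is short enough to restate as a standalone sketch (illustration only; the real function lives in monai.data.utils):

import numpy as np

def get_random_patch_sketch(dims, patch_size, generator=None):
    rng = np.random.default_rng() if generator is None else generator
    # choose the minimal corner of the patch, one coordinate per spatial dimension
    min_corner = tuple(
        int(rng.integers(0, ms - ps + 1)) if ms > ps else 0
        for ms, ps in zip(dims, patch_size)
    )
    # build the slices that cut the patch out of the source array
    return tuple(slice(mc, mc + ps) for mc, ps in zip(min_corner, patch_size))

print(get_random_patch_sketch((64, 64, 32), (32, 32, 16), np.random.default_rng(42)))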
4 changes: 2 additions & 2 deletions monai/data/wsi_datasets.py
@@ -239,7 +239,7 @@ def __init__(
**kwargs,
)
self.overlap = overlap
self.set_random_state(seed)
self.set_random_generator(seed)
# Set the offset config
self.random_offset = False
if isinstance(offset, str):
@@ -281,7 +281,7 @@ def _get_offset(self, sample):
offset_limits = tuple((-s, s) for s in self._get_size(sample))
else:
offset_limits = self.offset_limits
return tuple(self.R.randint(low, high) for low, high in offset_limits)
return tuple(self.R.integers(low, high) for low, high in offset_limits)
return self.offset

def _evaluate_patch_locations(self, sample):
13 changes: 8 additions & 5 deletions monai/transforms/compose.py
@@ -38,6 +38,7 @@
apply_transform,
)
from monai.utils import MAX_SEED, TraceKeys, TraceStatusKeys, ensure_tuple, get_seed
from monai.utils.utils_random_generator_adaptor import SupportsRandomGeneration

logger = get_logger(__name__)

@@ -248,19 +249,21 @@ def __init__(
self.map_items = map_items
self.unpack_items = unpack_items
self.log_stats = log_stats
self.set_random_state(seed=get_seed())
self.set_random_generator(seed=get_seed())
self.overrides = overrides

@LazyTransform.lazy.setter # type: ignore
def lazy(self, val: bool):
self._lazy = val

def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Compose:
super().set_random_state(seed=seed, state=state)
def set_random_generator(
self, seed: int | None = None, generator: SupportsRandomGeneration | None = None
) -> Compose:
super().set_random_generator(seed=seed, generator=generator)
for _transform in self.transforms:
if not isinstance(_transform, Randomizable):
continue
_transform.set_random_state(seed=self.R.randint(MAX_SEED, dtype="uint32"))
_transform.set_random_generator(seed=self.R.integers(MAX_SEED, dtype="uint32"))
return self

def randomize(self, data: Any | None = None) -> None:
@@ -731,7 +734,7 @@ def __call__(self, data, start=0, end=None, threading=False, lazy: bool | None = None):
if len(self.transforms) == 0:
return data

sample_size = self.R.randint(self.min_num_transforms, self.max_num_transforms + 1)
sample_size = self.R.integers(self.min_num_transforms, self.max_num_transforms + 1)
applied_order = self.R.choice(len(self.transforms), sample_size, replace=self.replace, p=self.weights).tolist()
_lazy = self._lazy if lazy is None else lazy

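As the Compose hunk shows, seeding the container now derives a fresh uint32 seed from the container's own generator for each random child, so a single call makes the whole pipeline reproducible. Roughly what that derivation looks like in isolation (transform classes omitted, only the seed bookkeeping is shown):

import numpy as np

MAX_SEED = np.iinfo(np.uint32).max + 1  # assumed MONAI constant

parent = np.random.default_rng(123)  # what set_random_generator(seed=123) would create
child_seeds = [int(parent.integers(MAX_SEED, dtype=np.uint32)) for _ in range(3)]

# each child transform would be re-seeded with one of these values, giving
# distinct but fully reproducible streams from the single parent seed
print(child_seeds)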