Skip to content

Commit

Permalink
map_classes_to_indices() ClassesToIndices() update (#6468)
Browse files Browse the repository at this point in the history
map_classes_to_indices utility (and ClassesToIndices transforms)
currently returns a list of coordinates for each class, but the type is
List[ of MetaTensors], where each class coordinate sub-list is its own
MetaTensor. This PR changes it to return a list of torch.Tensors (or
ndarray), since we don't need a MetaTensor here.

I ran into an issue with current MetaTensor list, where it would just
freeze without any errors, when trying to save the cached indices
(returned from ClassesToIndices to cache, which is ListProxy shared
mem). It randomly happens, but much more frequently when the number of
classes is large (e.g. 105 output classes, so ClassesToIndices returns a
list of 105 MetaTensors). I'm not sure what the cause of the freeze is,
but my guess is that ListProxy tries to pickle each element of this list
(and struggles with MetaTensors). Disabling MetaTensor return type here,
solves the issue.

since we don't really need meta tensor return type for each class
(coordinates), this seems like harmless fix. the original image/label
always will have meta info.

@wyli plz check.


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: myron <[email protected]>
  • Loading branch information
myron authored May 5, 2023
1 parent 25c9c39 commit c2a9a31
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,10 @@ def map_classes_to_indices(
if img_flat is not None:
label_flat = img_flat & label_flat
# no need to save the indices in GPU, otherwise, still need to move to CPU at runtime when crop by indices
cls_indices: NdarrayOrTensor = convert_data_type(nonzero(label_flat), device=torch.device("cpu"))[0]
output_type = torch.Tensor if isinstance(label, monai.data.MetaTensor) else None
cls_indices: NdarrayOrTensor = convert_data_type(
nonzero(label_flat), output_type=output_type, device=torch.device("cpu")
)[0]
if max_samples_per_class and len(cls_indices) > max_samples_per_class and len(cls_indices) > 1:
sample_id = np.round(np.linspace(0, len(cls_indices) - 1, max_samples_per_class)).astype(int)
indices.append(cls_indices[sample_id])
Expand Down

0 comments on commit c2a9a31

Please sign in to comment.