From d17319401dad63251384083cd2b17f6481cc5380 Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Wed, 8 Nov 2023 16:14:07 -0500 Subject: [PATCH] allow visualizers to plot predictions without ground truth (#1987) Co-authored-by: Adeel Hassan --- .../visualizer/classification_visualizer.py | 92 +++++++++++-------- .../visualizer/object_detection_visualizer.py | 15 +-- .../visualizer/regression_visualizer.py | 80 ++++++++++------ .../semantic_segmentation_visualizer.py | 59 +++++++----- .../dataset/visualizer/visualizer.py | 19 ++-- .../test_classification_visualizer.py | 18 ++++ .../test_object_detection_visualizer.py | 18 ++++ .../visualizer/test_regression_visualizer.py | 18 ++++ .../test_semantic_segmentation_visualizer.py | 20 ++++ 9 files changed, 238 insertions(+), 101 deletions(-) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/classification_visualizer.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/classification_visualizer.py index 7b43e58aa..1700aebbd 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/classification_visualizer.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/classification_visualizer.py @@ -1,4 +1,4 @@ -from typing import (Sequence, Optional) +from typing import TYPE_CHECKING, Optional, Sequence from textwrap import wrap import torch @@ -7,14 +7,17 @@ from rastervision.pytorch_learner.utils import (plot_channel_groups, channel_groups_to_imgs) +if TYPE_CHECKING: + from matplotlib.pyplot import Axes + class ClassificationVisualizer(Visualizer): """Plots samples from image classification Datasets.""" def plot_xyz(self, - axs: Sequence, + axs: Sequence['Axes'], x: torch.Tensor, - y: int, + y: Optional[int] = None, z: Optional[int] = None, plot_title: bool = True) -> None: channel_groups = self.get_channel_display_groups(x.shape[1]) @@ -30,46 +33,61 @@ def plot_xyz(self, # plot label class_names = self.class_names class_names = ['-\n-'.join(wrap(c, width=16)) for c in class_names] - if z is None: - # just display the class name as text - class_name = class_names[y] - label_ax.text( - .5, - .5, - class_name, - ha='center', - va='center', - fontdict={ - 'size': 20, - 'family': 'sans-serif' - }) - label_ax.set_xlim((0, 1)) - label_ax.set_ylim((0, 1)) - label_ax.axis('off') - else: - # display predicted class probabilities as a horizontal bar plot - # legend: green = ground truth, dark-red = wrong prediction, - # light-gray = other. In case predicted class matches ground truth, - # only one bar will be green and the others will be light-gray. - class_probabilities = z.softmax(dim=-1) - class_index_pred = z.argmax(dim=-1) + if y is not None and z is None: + self.plot_gt(label_ax, class_names, y) + elif z is not None: + self.plot_pred(label_ax, class_names, z, y=y) + if plot_title: + label_ax.set_title('Prediction') + + def plot_gt(self, ax: 'Axes', class_names: Sequence[str], y: torch.Tensor): + """Display ground truth class names as text.""" + class_name = class_names[y] + ax.text( + x=.5, + y=.5, + s=class_name, + ha='center', + va='center', + fontdict={ + 'size': 20, + 'family': 'sans-serif' + }) + ax.set_xlim((0, 1)) + ax.set_ylim((0, 1)) + ax.axis('off') + + def plot_pred(self, + ax: 'Axes', + class_names: Sequence[str], + z: torch.Tensor, + y: Optional[torch.Tensor] = None): + """Plot predictions. + + Plots predicted class probabilities as a horizontal bar plot. If ground + truth, y, is provided, the bar colors represent: green = ground truth, + dark-red = wrong prediction, light-gray = other. In case predicted + class matches ground truth, only one bar will be green and the others + will be light-gray. + """ + class_probabilities = z.softmax(dim=-1) + class_index_pred = z.argmax(dim=-1) + bar_colors = ['lightgray'] * len(z) + if y is not None: class_index_gt = y - bar_colors = ['lightgray'] * len(z) if class_index_pred == class_index_gt: bar_colors[class_index_pred] = 'green' else: bar_colors[class_index_pred] = 'darkred' bar_colors[class_index_gt] = 'green' - label_ax.barh( - y=class_names, - width=class_probabilities, - color=bar_colors, - edgecolor='black') - label_ax.set_xlim((0, 1)) - label_ax.xaxis.grid(linestyle='--', alpha=1) - label_ax.set_xlabel('Probability') - if plot_title: - label_ax.set_title('Prediction') + ax.barh( + y=class_names, + width=class_probabilities, + color=bar_colors, + edgecolor='black') + ax.set_xlim((0, 1)) + ax.xaxis.grid(linestyle='--', alpha=1) + ax.set_xlabel('Probability') def get_plot_ncols(self, **kwargs) -> int: x = kwargs['x'] diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/object_detection_visualizer.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/object_detection_visualizer.py index e6f9fbe0b..eea6cb63b 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/object_detection_visualizer.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/object_detection_visualizer.py @@ -18,15 +18,18 @@ def get_collate_fn(self): def plot_xyz(self, axs: Sequence, x: torch.Tensor, - y: BoxList, + y: Optional[BoxList] = None, z: Optional[BoxList] = None, plot_title: bool = True) -> None: - y = y if z is None else z channel_groups = self.get_channel_display_groups(x.shape[1]) + imgs = channel_groups_to_imgs(x, channel_groups) - class_names = self.class_names - class_colors = self.class_colors + if y is not None or z is not None: + y = y if z is None else z + class_names = self.class_names + class_colors = self.class_colors + imgs = [ + draw_boxes(img, y, class_names, class_colors) for img in imgs + ] - imgs = channel_groups_to_imgs(x, channel_groups) - imgs = [draw_boxes(img, y, class_names, class_colors) for img in imgs] plot_channel_groups(axs, imgs, channel_groups, plot_title=plot_title) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/regression_visualizer.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/regression_visualizer.py index 3dc076f49..f2ec7282b 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/regression_visualizer.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/regression_visualizer.py @@ -1,4 +1,4 @@ -from typing import (Sequence, Optional) +from typing import TYPE_CHECKING, Optional, Sequence from textwrap import wrap import torch @@ -9,6 +9,9 @@ from rastervision.pytorch_learner.utils import (plot_channel_groups, channel_groups_to_imgs) +if TYPE_CHECKING: + from matplotlib.pyplot import Axes + class RegressionVisualizer(Visualizer): """Plots samples from image regression Datasets.""" @@ -22,7 +25,7 @@ def plot_xyz(self, channel_groups = self.get_channel_display_groups(x.shape[1]) img_axes = axs[:-1] - label_ax = axs[-1] + label_ax: 'Axes' = axs[-1] # plot image imgs = channel_groups_to_imgs(x, channel_groups) @@ -32,44 +35,63 @@ def plot_xyz(self, # plot label class_names = self.class_names class_names = ['-\n-'.join(wrap(c, width=8)) for c in class_names] - if z is None: - # display targets as a horizontal bar plot - bars_gt = label_ax.barh( - y=class_names, width=y, color='lightgray', edgecolor='black') - # show values on the end of bars - label_ax.bar_label(bars_gt, fmt='%.3f', padding=3) + + if y is not None and z is None: + self.plot_gt(label_ax, class_names, y) if plot_title: label_ax.set_title('Ground truth') - else: - # display targets and predictions as a grouped horizontal bar plot - bar_thickness = 0.35 - y_tick_locs = np.arange(len(class_names)) - bars_gt = label_ax.barh( + elif z is not None: + self.plot_pred(label_ax, class_names, z, y=y) + + def plot_gt(self, ax: 'Axes', class_names: Sequence[str], y: torch.Tensor): + """Plot targets as a horizontal bar plot with values at the tips.""" + bars_gt = ax.barh( + y=class_names, width=y, color='lightgray', edgecolor='black') + # show values on the end of bars + ax.bar_label(bars_gt, fmt='%.3f', padding=3) + + ax.xaxis.grid(linestyle='--', alpha=1) + ax.set_xlabel('Value') + ax.spines['right'].set_visible(False) + ax.get_yaxis().tick_left() + + def plot_pred(self, + ax: 'Axes', + class_names: Sequence[str], + z: torch.Tensor, + y: Optional[torch.Tensor] = None): + """Plot targets and predictions as a grouped horizontal bar plot.""" + # display targets and predictions as a grouped horizontal bar plot + bar_thickness = 0.35 if y is not None else 0.70 + y_tick_locs = np.arange(len(class_names)) + if y is not None: + bars_gt = ax.barh( y=y_tick_locs + bar_thickness / 2, width=y, height=bar_thickness, color='lightgray', edgecolor='black', label='true') - bars_pred = label_ax.barh( - y=y_tick_locs - bar_thickness / 2, - width=z, - height=bar_thickness, - color=plt.get_cmap('tab10')(0), - edgecolor='black', - label='pred') # show values on the end of bars - label_ax.bar_label(bars_gt, fmt='%.3f', padding=3) - label_ax.bar_label(bars_pred, fmt='%.3f', padding=3) + ax.bar_label(bars_gt, fmt='%.3f', padding=3) + + bars_pred = ax.barh( + y=y_tick_locs - bar_thickness / 2, + width=z, + height=bar_thickness, + color=plt.get_cmap('tab10')(0), + edgecolor='black', + label='pred') + # show values on the end of bars + ax.bar_label(bars_pred, fmt='%.3f', padding=3) - label_ax.set_yticks(ticks=y_tick_locs, labels=class_names) - label_ax.legend( - ncol=2, loc='lower center', bbox_to_anchor=(0.5, 1.0)) + ax.set_yticks(ticks=y_tick_locs, labels=class_names) + ax.legend(ncol=2, loc='lower center', bbox_to_anchor=(0.5, 1.0)) - label_ax.xaxis.grid(linestyle='--', alpha=1) - label_ax.set_xlabel('Target value') - label_ax.spines['right'].set_visible(False) - label_ax.get_yaxis().tick_left() + ax.xaxis.grid(linestyle='--', alpha=1) + ax.set_xlabel('Value') + ax.spines['right'].set_visible(False) + ax.get_yaxis().tick_left() def get_plot_ncols(self, **kwargs) -> int: x = kwargs['x'] diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/semantic_segmentation_visualizer.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/semantic_segmentation_visualizer.py index 9febf7529..a45ef93ac 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/semantic_segmentation_visualizer.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/semantic_segmentation_visualizer.py @@ -1,4 +1,4 @@ -from typing import (Sequence, Optional, Union) +from typing import TYPE_CHECKING, Optional, Sequence, Union import torch import numpy as np @@ -9,6 +9,10 @@ from rastervision.pytorch_learner.utils import ( color_to_triple, plot_channel_groups, channel_groups_to_imgs) +if TYPE_CHECKING: + from matplotlib.pyplot import Axes + from matplotlib.colors import Colormap + class SemanticSegmentationVisualizer(Visualizer): """Plots samples from semantic segmentation Datasets.""" @@ -16,19 +20,21 @@ class SemanticSegmentationVisualizer(Visualizer): def plot_xyz(self, axs: Sequence, x: torch.Tensor, - y: Union[torch.Tensor, np.ndarray], + y: Optional[Union[torch.Tensor, np.ndarray]] = None, z: Optional[torch.Tensor] = None, plot_title: bool = True) -> None: channel_groups = self.get_channel_display_groups(x.shape[1]) img_axes = axs[:len(channel_groups)] - label_ax = axs[len(channel_groups)] # plot image imgs = channel_groups_to_imgs(x, channel_groups) plot_channel_groups( img_axes, imgs, channel_groups, plot_title=plot_title) + if y is None and z is None: + return + # plot labels class_colors = self.class_colors colors = [ @@ -38,27 +44,17 @@ def plot_xyz(self, colors = np.array(colors) / 255. cmap = mcolors.ListedColormap(colors) - label_ax.imshow( - y, vmin=0, vmax=len(colors), cmap=cmap, interpolation='none') - if plot_title: - label_ax.set_title(f'Ground truth') - label_ax.set_xticks([]) - label_ax.set_yticks([]) + if y is not None: + label_ax: 'Axes' = axs[len(channel_groups)] + self.plot_gt(label_ax, y, num_classes=len(colors), cmap=cmap) + if plot_title: + label_ax.set_title('Ground truth') - # plot predictions if z is not None: pred_ax = axs[-1] - preds = z.argmax(dim=0) - pred_ax.imshow( - preds, - vmin=0, - vmax=len(colors), - cmap=cmap, - interpolation='none') + self.plot_pred(pred_ax, z, num_classes=len(colors), cmap=cmap) if plot_title: - pred_ax.set_title(f'Predicted labels') - pred_ax.set_xticks([]) - pred_ax.set_yticks([]) + pred_ax.set_title('Predicted labels') # add a legend to the rightmost subplot class_names = self.class_names @@ -72,11 +68,30 @@ def plot_xyz(self, loc='center left', bbox_to_anchor=(1., 0.5)) + def plot_gt(self, ax: 'Axes', y: Union[torch.Tensor, np.ndarray], + num_classes: int, cmap: 'Colormap', **kwargs): + ax.imshow( + y, + vmin=0, + vmax=num_classes, + cmap=cmap, + interpolation='none', + **kwargs) + ax.set_xticks([]) + ax.set_yticks([]) + + def plot_pred(self, ax: 'Axes', z: Union[torch.Tensor, np.ndarray], + num_classes: int, cmap: 'Colormap', **kwargs): + if z.ndim == 3: + z = z.argmax(dim=0) + self.plot_gt(ax, y=z, num_classes=num_classes, cmap=cmap, **kwargs) + def get_plot_ncols(self, **kwargs) -> int: x = kwargs['x'] nb_img_channels = x.shape[1] ncols = len(self.get_channel_display_groups(nb_img_channels)) + 1 - z = kwargs.get('z') - if z is not None: + if kwargs.get('y') is not None: + ncols += 1 + if kwargs.get('z') is not None: ncols += 1 return ncols diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py index ddf3ca62f..70c4adf6a 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py @@ -5,8 +5,8 @@ import numpy as np import torch from torch import Tensor -import albumentations as A from torch.utils.data import DataLoader +import albumentations as A import matplotlib.pyplot as plt from rastervision.pipeline.file_system import make_dir @@ -34,7 +34,7 @@ class Visualizer(ABC): def __init__(self, class_names: List[str], class_colors: Optional[List[Union[str, RGBTuple]]] = None, - transform: Optional[Dict] = A.to_dict(MinMaxNormalize()), + transform: Optional[Dict] = None, channel_display_groups: Optional[Union[Dict[ str, ChannelInds], Sequence[ChannelInds]]] = None): """Constructor. @@ -62,6 +62,8 @@ def __init__(self, """ self.class_names = class_names self.class_colors = ensure_class_colors(self.class_names, class_colors) + if transform is None: + transform = A.to_dict(MinMaxNormalize()) self.transform = validate_albumentation_transform(transform) self._channel_display_groups = validate_channel_display_groups( channel_display_groups) @@ -70,7 +72,7 @@ def __init__(self, def plot_xyz(self, axs, x: Tensor, - y: Sequence, + y: Optional[Sequence] = None, z: Optional[Sequence] = None, plot_title: bool = True): """Plot image, ground truth labels, and predicted labels. @@ -84,7 +86,7 @@ def plot_xyz(self, def plot_batch(self, x: Tensor, - y: Sequence, + y: Optional[Sequence] = None, output_path: Optional[str] = None, z: Optional[Sequence] = None, batch_limit: Optional[int] = None, @@ -133,7 +135,7 @@ def plot_batch(self, for args in plot_xyz_args[1:]: args['plot_title'] = False _x = x[i] - _y = [y[i]] * T + _y = None if y is None else [y[i]] * T _z = None if z is None else [z[i]] * T self._plot_batch(fig, axs, plot_xyz_args, _x, y=_y, z=_z) else: @@ -166,9 +168,12 @@ def _plot_batch( imgs = [tf(image=img)['image'] for img in x.numpy()] x = torch.from_numpy(np.stack(imgs)) + if y is None: + y = [None] * len(x) + if z is None: + z = [None] * len(x) for i, row_axs in enumerate(axs): - _z = None if z is None else z[i] - self.plot_xyz(row_axs, x[i], y[i], z=_z, **plot_xyz_args[i]) + self.plot_xyz(row_axs, x[i], y[i], z=z[i], **plot_xyz_args[i]) def get_channel_display_groups( self, nb_img_channels: int diff --git a/tests/pytorch_learner/dataset/visualizer/test_classification_visualizer.py b/tests/pytorch_learner/dataset/visualizer/test_classification_visualizer.py index cb49cf406..8b758caf5 100644 --- a/tests/pytorch_learner/dataset/visualizer/test_classification_visualizer.py +++ b/tests/pytorch_learner/dataset/visualizer/test_classification_visualizer.py @@ -31,6 +31,15 @@ def test_plot_batch(self): z = torch.tensor([[0.9, 0.1], [0.6, 0.4]]) self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + # w/ z, w/o y + viz = ClassificationVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 4, 256, 256)) + y = None + z = torch.tensor([[0.9, 0.1], [0.6, 0.4]]) + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + def test_plot_batch_temporal(self): # w/o z viz = ClassificationVisualizer( @@ -48,3 +57,12 @@ def test_plot_batch_temporal(self): y = torch.tensor([0, 1]) z = torch.tensor([[0.9, 0.1], [0.6, 0.4]]) self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + + # w/ z, w/o y + viz = ClassificationVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 3, 4, 256, 256)) + y = None + z = torch.tensor([[0.9, 0.1], [0.6, 0.4]]) + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) diff --git a/tests/pytorch_learner/dataset/visualizer/test_object_detection_visualizer.py b/tests/pytorch_learner/dataset/visualizer/test_object_detection_visualizer.py index 5ec6a362b..eb4ff0f6d 100644 --- a/tests/pytorch_learner/dataset/visualizer/test_object_detection_visualizer.py +++ b/tests/pytorch_learner/dataset/visualizer/test_object_detection_visualizer.py @@ -42,6 +42,15 @@ def test_plot_batch(self): z = [random_boxlist(_x) for _x in x] self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + # w/ z, w/o y + viz = ObjectDetectionVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 4, 256, 256)) + y = None + z = [random_boxlist(_x) for _x in x] + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + def test_plot_batch_temporal(self): # w/o z viz = ObjectDetectionVisualizer( @@ -59,3 +68,12 @@ def test_plot_batch_temporal(self): y = [random_boxlist(_x) for _x in x] z = [random_boxlist(_x) for _x in x] self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + + # w/ z, w/o y + viz = ObjectDetectionVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 3, 4, 256, 256)) + y = None + z = [random_boxlist(_x) for _x in x] + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) diff --git a/tests/pytorch_learner/dataset/visualizer/test_regression_visualizer.py b/tests/pytorch_learner/dataset/visualizer/test_regression_visualizer.py index 80ed35715..b299db5f3 100644 --- a/tests/pytorch_learner/dataset/visualizer/test_regression_visualizer.py +++ b/tests/pytorch_learner/dataset/visualizer/test_regression_visualizer.py @@ -31,6 +31,15 @@ def test_plot_batch(self): z = torch.tensor([0.1, 2]) self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + # w/ z, w/o y + viz = RegressionVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 4, 256, 256)) + y = None + z = torch.tensor([0.1, 2]) + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + def test_plot_batch_temporal(self): # w/o z viz = RegressionVisualizer( @@ -48,3 +57,12 @@ def test_plot_batch_temporal(self): y = torch.tensor([0.2, 1.3]) z = torch.tensor([0.1, 2]) self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + + # w/ z, w/o y + viz = RegressionVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + x = torch.randn(size=(2, 3, 4, 256, 256)) + y = None + z = torch.tensor([0.1, 2]) + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) diff --git a/tests/pytorch_learner/dataset/visualizer/test_semantic_segmentation_visualizer.py b/tests/pytorch_learner/dataset/visualizer/test_semantic_segmentation_visualizer.py index edd73fa81..6d48dc1a1 100644 --- a/tests/pytorch_learner/dataset/visualizer/test_semantic_segmentation_visualizer.py +++ b/tests/pytorch_learner/dataset/visualizer/test_semantic_segmentation_visualizer.py @@ -32,6 +32,16 @@ def test_plot_batch(self): z = torch.randn(size=(2, num_classes, 256, 256)).softmax(dim=-3) self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + # w/ z, w/o y + viz = SemanticSegmentationVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + num_classes = 2 + x = torch.randn(size=(2, 4, 256, 256)) + y = None + z = torch.randn(size=(2, num_classes, 256, 256)).softmax(dim=-3) + self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + def test_plot_batch_temporal(self): # w/o z viz = SemanticSegmentationVisualizer( @@ -54,3 +64,13 @@ def test_plot_batch_temporal(self): self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) # w/ z, batch size = 1 self.assertNoError(lambda: viz.plot_batch(x[[0]], y[[0]])) + + # w/ z, w/o y + viz = SemanticSegmentationVisualizer( + class_names=['bg', 'fg'], + channel_display_groups=dict(RGB=[0, 1, 2], IR=[3])) + num_classes = 2 + x = torch.randn(size=(2, 3, 4, 256, 256)) + y = None + z = torch.randn(size=(2, num_classes, 256, 256)).softmax(dim=-3) + self.assertNoError(lambda: viz.plot_batch(x, y, z=z))