diff --git a/notebooks/Unit 3 - Vision Transformers/fine-tuning-multilabel-image-classification.ipynb b/notebooks/Unit 3 - Vision Transformers/fine-tuning-multilabel-image-classification.ipynb
index 387330067..c220fb6b1 100644
--- a/notebooks/Unit 3 - Vision Transformers/fine-tuning-multilabel-image-classification.ipynb
+++ b/notebooks/Unit 3 - Vision Transformers/fine-tuning-multilabel-image-classification.ipynb
@@ -579,29 +579,49 @@
    "outputs": [],
    "source": [
     "def train_transforms(batch):\n",
-    "    # convert all images in batch to RGB to avoid grayscale or transparent images\n",
-    "    batch['image'] = [x.convert('RGB') for x in batch['image']]\n",
-    "    # apply torchvision.transforms per sample in the batch\n",
-    "    inputs = [train_tfms(x) for x in batch['image']]\n",
-    "    batch['pixel_values'] = inputs\n",
-    "    \n",
-    "    # one-hot encoding the labels\n",
-    "    labels = torch.tensor(batch['classes'])\n",
-    "    batch['labels'] = nn.functional.one_hot(labels,num_classes=20).sum(dim=1)\n",
-    "    \n",
+    "    # Convert all images to RGB format\n",
+    "    if isinstance(batch['image'], list):\n",
+    "        # Batch processing\n",
+    "        batch['image'] = [x.convert('RGB') for x in batch['image']]\n",
+    "        inputs = [train_tfms(x) for x in batch['image']]\n",
+    "        batch['pixel_values'] = torch.stack(inputs)  # Stack tensor outputs\n",
+    "    else:\n",
+    "        # Single sample processing\n",
+    "        batch['image'] = batch['image'].convert('RGB')\n",
+    "        batch['pixel_values'] = train_tfms(batch['image'])\n",
+    "\n",
+    "    # One-hot encode the multilabels\n",
+    "    all_labels = [torch.tensor(labels) for labels in batch['classes']]\n",
+    "\n",
+    "    # Create one-hot encoding for each image's classes\n",
+    "    one_hot_labels = [nn.functional.one_hot(label, num_classes=20).sum(dim=0) for label in all_labels]\n",
+    "\n",
+    "    # Stack them into a batch\n",
+    "    batch['labels'] = torch.stack(one_hot_labels)\n",
+    "\n",
     "    return batch\n",
     "\n",
     "def valid_transforms(batch):\n",
-    "    # convert all images in batch to RGB to avoid grayscale or transparent images\n",
-    "    batch['image'] = [x.convert('RGB') for x in batch['image']]\n",
-    "    # apply torchvision.transforms per sample in the batch\n",
-    "    inputs = [valid_tfms(x) for x in batch['image']]\n",
-    "    batch['pixel_values'] = inputs\n",
-    "    \n",
-    "    # one-hot encoding the labels\n",
-    "    labels = torch.tensor(batch['classes'])\n",
-    "    batch['labels'] = nn.functional.one_hot(labels,num_classes=20).sum(dim=1)\n",
-    "    \n",
+    "    # Convert all images to RGB format\n",
+    "    if isinstance(batch['image'], list):\n",
+    "        # Batch processing\n",
+    "        batch['image'] = [x.convert('RGB') for x in batch['image']]\n",
+    "        inputs = [valid_tfms(x) for x in batch['image']]\n",
+    "        batch['pixel_values'] = torch.stack(inputs)  # Stack tensor outputs\n",
+    "    else:\n",
+    "        # Single sample processing\n",
+    "        batch['image'] = batch['image'].convert('RGB')\n",
+    "        batch['pixel_values'] = valid_tfms(batch['image'])\n",
+    "\n",
+    "    # One-hot encode the multilabels\n",
+    "    all_labels = [torch.tensor(labels) for labels in batch['classes']]\n",
+    "\n",
+    "    # Create one-hot encoding for each image's classes\n",
+    "    one_hot_labels = [nn.functional.one_hot(label, num_classes=20).sum(dim=0) for label in all_labels]\n",
+    "\n",
+    "    # Stack them into a batch\n",
+    "    batch['labels'] = torch.stack(one_hot_labels)\n",
+    "\n",
     "    return batch"
    ]
   },
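A minimal sketch (not part of the diff) of the multi-hot label encoding the new transforms perform, assuming each sample's `classes` field is a list of integer label indices and 20 classes total, as in the notebook; the batch data below is invented purely for illustration:

```python
import torch
import torch.nn as nn

# Hypothetical batch: first image tagged with classes 3 and 7, second with class 0 only.
batch_classes = [[3, 7], [0]]

# One-hot encode each image's class list, then collapse it into one multi-hot vector per image.
one_hot_labels = [
    nn.functional.one_hot(torch.tensor(labels), num_classes=20).sum(dim=0)
    for labels in batch_classes
]

# Stack into a (batch_size, num_classes) multi-hot target matrix
# (cast to float before feeding a BCE-style multilabel loss).
labels = torch.stack(one_hot_labels)
print(labels.shape)        # torch.Size([2, 20])
print(labels[0].tolist())  # 1s at indices 3 and 7, 0s elsewhere
```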