
collate_fn is not handled properly in TorchNonGeoDataModule #325

Open
takaomoriyama opened this issue Dec 12, 2024 · 5 comments

Comments

@takaomoriyama
Member

takaomoriyama commented Dec 12, 2024

Describe the issue
When terratorch.datamodules.TorchNonGeoDataModule creates an instance of a given TorchGeo datamodule that sets its own collation function via the collate_fn attribute, that collation function is ignored and replaced with NonGeoDataModule's default, default_collate() from torch.utils.data (https://github.com/microsoft/torchgeo/blob/801e94746f9fe9d1b84b4220c3eb2fc8248a59a8/torchgeo/datamodules/geo.py#L387).

For example, torchgeo.datamodules.VHR10DataModule wants to use its own collate_fn_detection() (see https://github.com/microsoft/torchgeo/blob/main/torchgeo/datamodules/vhr10.py#L53) to handle samples of unequal size.
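For context, a detection-style collate function typically keeps variable-sized tensors in Python lists instead of stacking them, roughly like the sketch below (an illustration of the idea, not torchgeo's exact implementation):

def detection_style_collate(batch):
    # Images and per-image annotations may differ in shape, so keep them as
    # lists of tensors instead of calling torch.stack() on them.
    return {
        "image": [sample["image"] for sample in batch],    # list of [C, H_i, W_i] tensors
        "boxes": [sample["boxes"] for sample in batch],    # list of [N_i, 4] tensors
        "labels": [sample["labels"] for sample in batch],  # list of [N_i] tensors
    }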

To Reproduce (optional, but appreciated)

$ pip install pycocotools
$ python collate_fn_test.py
$ cat collate_fn_test.py
from terratorch.datamodules import TorchNonGeoDataModule
from torchgeo.datamodules import VHR10DataModule
from torchgeo.trainers import ObjectDetectionTask
from lightning.pytorch import Trainer
from torch import Tensor

datamodule = TorchNonGeoDataModule(
    VHR10DataModule,
    batch_size=16,
    root='data/VHR10',
    download=True,
)
print(f'{datamodule.collate_fn=}')

class VariableSizeInputObjectDetectionTask(ObjectDetectionTask):
    # Define the training step
    def training_step(self, batch, batch_idx, dataloader_idx=0):
        x = batch["image"]  # Image
        batch_size = len(x)  # Set batch size (number of images)
        y = [
            {"boxes": batch["boxes"][i], "labels": batch["labels"][i]}
            for i in range(batch_size)
        ] # Extract bounding box and label information for each image
        loss_dict = self(x, y)  # Loss
        train_loss: Tensor = sum(loss_dict.values())  # Training loss (sum of loss values)
        self.log_dict(loss_dict)  # Record loss values
        return train_loss  # Return training loss

task = VariableSizeInputObjectDetectionTask(
    model="faster-rcnn",  # Faster R-CNN model
    backbone="resnet18",  # ResNet18 neural network architecture
    weights=True,  # Use pretrained weights
    in_channels=3,  # Number of channels in the input image (RGB images)
    num_classes=11,  # Number of classes to classify (10 + background)
    trainable_layers=3,  # Number of trainable layers
    lr=1e-3,  # Learning rate
    patience=10,  # Set the number of patience iterations for early stopping
    freeze_backbone=False,  # Whether to train with the backbone network weights unfrozen
)

trainer = Trainer(
    devices=1,
    precision="16-mixed",
    max_epochs=1,
    default_root_dir='output/collate_fn_test',
    log_every_n_steps=1,
    check_val_every_n_epoch=1
)
_ = trainer.fit(model=task, datamodule=datamodule)

Screenshots or log output (optional)

datamodule.collate_fn=<function default_collate at 0x7fe30ab31800>
/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/accelerator_connector.py:512: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Files already downloaded and verified
loading annotations into memory...
Done (t=0.03s)
creating index...
index created!

  | Name         | Type             | Params | Mode
----------------------------------------------------------
0 | model        | FasterRCNN       | 28.3 M | train
1 | val_metrics  | MetricCollection | 0      | train
2 | test_metrics | MetricCollection | 0      | train
----------------------------------------------------------
28.2 M    Trainable params
156 K     Non-trainable params
28.3 M    Total params
113.280   Total estimated model params size (MB)
110       Modules in train mode
0         Modules in eval mode
Sanity Checking: |          | 0/? [00:00<?, ?it/s]/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
Traceback (most recent call last):
  File "/dccstor/usgs_dem/moriyama/dev/IBM/terratorch-collate_fn/examples/confs/collate_fn_test.py", line 52, in <module>
    _ = trainer.fit(model=task, datamodule=datamodule)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1023, in _run_stage
    self._run_sanity_check()
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1052, in _run_sanity_check
    val_loop.run()
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py", line 178, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 128, in run
    batch, batch_idx, dataloader_idx = next(data_fetcher)
                                       ^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/loops/fetchers.py", line 133, in __next__
    batch = super().__next__()
            ^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/loops/fetchers.py", line 60, in __next__
    batch = next(self.iterator)
            ^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/utilities/combined_loader.py", line 341, in __next__
    out = next(self._iterator)
          ^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/lightning/pytorch/utilities/combined_loader.py", line 142, in __next__
    out = next(self.iterators[0])
          ^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 631, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 675, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 316, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 154, in collate
    clone.update({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 154, in <dictcomp>
    clone.update({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 141, in collate
    return collate_fn_map[elem_type](batch, collate_fn_map=collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/dccstor/usgs_dem/moriyama/.conda/envs/terratorch-os/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 213, in collate_tensor_fn
    return torch.stack(batch, 0, out=out)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: stack expects each tensor to be equal size, but got [3, 540, 831] at entry 0 and [3, 649, 973] at entry 1

You can see in the first line of the output that collate_fn has been initialized with the default_collate() function.

In this case the error occurs on the batch['image'] data, and we could avoid it by resizing all images to the same size. However, because VHR10 is an object detection dataset, the label data also differs in size from sample to sample (depending on the number of annotated objects), so resizing the images alone would not solve the problem.

@takaomoriyama
Member Author

takaomoriyama commented Dec 12, 2024

The root cause might be that in __init__() of TorchNonGeoDataModule, super().__init__() is called after the instance of cls (in this case VHR10DataModule) has been created:

self._proxy = cls(num_workers=num_workers, **kwargs)
super().__init__(self._proxy.dataset_class)  # dummy arg

super().__init__() then initializes collate_fn with its default (default_collate()), and the collate_fn setter of TorchNonGeoDataModule overwrites the collate_fn attribute of the VHR10DataModule instance:

@collate_fn.setter
def collate_fn(self, value):
    self._proxy.collate_fn = value
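To make the propagation concrete, here is a minimal, self-contained sketch of the pattern (class names such as Base, Wrapper and FakeVHR10DataModule are simplified stand-ins, not the actual terratorch/torchgeo code):

from torch.utils.data import default_collate

def collate_fn_detection(batch):
    # Stand-in for the datamodule's own collate function.
    return batch

class FakeVHR10DataModule:
    # Stand-in for VHR10DataModule: it sets its own collate_fn in __init__().
    dataset_class = object
    def __init__(self, **kwargs):
        self.collate_fn = collate_fn_detection

class Base:
    # Stand-in for torchgeo's NonGeoDataModule.
    def __init__(self, dataset_class):
        self.collate_fn = default_collate  # this assignment triggers the setter below

class Wrapper(Base):
    # Stand-in for TorchNonGeoDataModule.
    def __init__(self, cls, **kwargs):
        self._proxy = cls(**kwargs)                  # proxy starts with collate_fn_detection
        super().__init__(self._proxy.dataset_class)  # ...which is overwritten here

    @property
    def collate_fn(self):
        return self._proxy.collate_fn

    @collate_fn.setter
    def collate_fn(self, value):
        self._proxy.collate_fn = value  # default_collate replaces collate_fn_detection

dm = Wrapper(FakeVHR10DataModule)
print(dm.collate_fn)  # <function default_collate ...>, not collate_fn_detection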

@takaomoriyama
Member Author

My tentative solution is to create the instance of cls again after calling super().__init__():

   self._proxy = cls(num_workers=num_workers, **kwargs)
   super().__init__(self._proxy.dataset_class)  # dummy arg
   self._proxy = cls(num_workers=num_workers, **kwargs) # Create an object again

@takaomoriyama
Member Author

takaomoriyama commented Dec 17, 2024

@paolofraccaro suggested that instead of instantiating the cls object twice, we could pass None to super().__init__() to avoid overwriting self.collate_fn:

   self._proxy = cls(num_workers=num_workers, **kwargs)
   super().__init__(None)  # dummy arg

This does not work, because the overwrite of self.collate_fn does not happen in the instance of cls (VHR10DataModule in this case); it happens first in the instance of TorchNonGeoDataModule (actually in its superclass NonGeoDataModule) and is then propagated to the VHR10DataModule instance via the collate_fn setter of TorchNonGeoDataModule:

@collate_fn.setter
def collate_fn(self, value):
    self._proxy.collate_fn = value

@takaomoriyama
Member Author

takaomoriyama commented Dec 17, 2024

Another idea is to swap the two lines as follows:

   super().__init__(None)  # dummy arg
   self._proxy = cls(num_workers=num_workers, **kwargs)

and also to modify the setter so that it does not reference the not-yet-defined attribute self._proxy:

@collate_fn.setter
def collate_fn(self, value):
    if hasattr(self, '_proxy'):
        self._proxy.collate_fn = value

The latter method is employed in commit d51af02.
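With the swapped order and the hasattr guard together, the construction sequence in the simplified sketch from the earlier comment now behaves as intended (again a sketch, not the exact change in d51af02):

class FixedWrapper(Base):
    # Fixed stand-in for TorchNonGeoDataModule (reusing Base and
    # FakeVHR10DataModule from the sketch above).
    def __init__(self, cls, **kwargs):
        super().__init__(None)       # assigns default_collate, but the guarded setter is a no-op
        self._proxy = cls(**kwargs)  # proxy keeps its own collate_fn_detection

    @property
    def collate_fn(self):
        return self._proxy.collate_fn

    @collate_fn.setter
    def collate_fn(self, value):
        if hasattr(self, '_proxy'):  # _proxy does not exist yet during super().__init__()
            self._proxy.collate_fn = value

dm = FixedWrapper(FakeVHR10DataModule)
print(dm.collate_fn)  # now reports collate_fn_detection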

@takaomoriyama
Member Author

takaomoriyama commented Jan 6, 2025

Found that other attributes (aug, train_aug, val_aug, test_aug, and predict_aug) of TorchNonGeoDataModule require the same @property/setter treatment so that user-specified augmentation modules take effect.
The fix is in 7198d85.
The root cause is the same as for collate_fn described above.
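For illustration, the same guarded-proxy pattern applied to the augmentation attributes would look roughly like this (a sketch based on the description above, not the exact code in 7198d85):

class FixedWrapper(Base):
    # Repeat the guarded property/setter pair for each augmentation attribute.
    @property
    def aug(self):
        return self._proxy.aug

    @aug.setter
    def aug(self, value):
        if hasattr(self, '_proxy'):
            self._proxy.aug = value

    @property
    def train_aug(self):
        return self._proxy.train_aug

    @train_aug.setter
    def train_aug(self, value):
        if hasattr(self, '_proxy'):
            self._proxy.train_aug = value

    # ...and likewise for val_aug, test_aug, and predict_aug.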
