From 6fc5e6227ee30167efd6c4dc059118989345cfe3 Mon Sep 17 00:00:00 2001 From: Carlos Gomes Date: Thu, 30 May 2024 15:01:45 +0200 Subject: [PATCH] add burn scars and multi temporal crop examples Signed-off-by: Carlos Gomes --- examples/confs/burn_scars.yaml | 131 ++++++++++++++++++ examples/confs/multi_temporal_crop.yaml | 169 ++++++++++++++++++++++++ 2 files changed, 300 insertions(+) create mode 100644 examples/confs/burn_scars.yaml create mode 100644 examples/confs/multi_temporal_crop.yaml diff --git a/examples/confs/burn_scars.yaml b/examples/confs/burn_scars.yaml new file mode 100644 index 00000000..94144bd0 --- /dev/null +++ b/examples/confs/burn_scars.yaml @@ -0,0 +1,131 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: + name: fire_scars + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + - class_path: EarlyStopping + init_args: + monitor: val/loss + patience: 40 + + max_epochs: 200 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: + +# dataset available: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars +data: + class_path: GenericNonGeoSegmentationDataModule + init_args: + batch_size: 4 + num_workers: 8 + dataset_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + output_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + rgb_indices: + - 0 + - 1 + - 2 + train_transform: + - class_path: albumentations.RandomCrop + init_args: + height: 224 + width: 224 + - class_path: albumentations.HorizontalFlip + init_args: + p: 0.5 + - class_path: ToTensorV2 + no_data_replace: 0 + no_label_replace: -1 + train_data_root: /training + train_label_data_root: /training + val_data_root: /validation + val_label_data_root: /validation + test_data_root: /validation + test_label_data_root: /validation + img_grep: "*_merged.tif" + label_grep: "*.mask.tif" + means: + - 0.033349706741586264 + - 0.05701185520536176 + - 0.05889748132001316 + - 0.2323245113436119 + - 0.1972854853760658 + - 0.11944914225186566 + stds: + - 0.02269135568823774 + - 0.026807560223070237 + - 0.04004109844362779 + - 0.07791732423672691 + - 0.08708738838140137 + - 0.07241979477437814 + num_classes: 2 + +model: + class_path: terratorch.tasks.SemanticSegmentationTask + init_args: + model_args: + decoder: FCNDecoder + pretrained: true + backbone: prithvi_vit_100 + decoder_channels: 256 + in_channels: 6 + bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + num_frames: 1 + num_classes: 2 + head_dropout: 0.1 + decoder_num_convs: 4 + head_channel_list: + - 256 + loss: dice + plot_on_val: 10 + ignore_index: -1 + freeze_backbone: false + freeze_decoder: false + model_factory: PrithviModelFactory + tiled_inference_parameters: + h_crop: 512 + h_stride: 496 + w_crop: 512 + w_stride: 496 + average_patches: true +optimizer: + class_path: torch.optim.Adam + init_args: + lr: 1.5e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss diff --git a/examples/confs/multi_temporal_crop.yaml b/examples/confs/multi_temporal_crop.yaml new file mode 100644 index 00000000..e8390f8b --- /dev/null +++ b/examples/confs/multi_temporal_crop.yaml @@ -0,0 +1,169 @@ +# lightning.pytorch==2.1.1 +seed_everything: 0 +trainer: + accelerator: auto + strategy: auto + devices: auto + num_nodes: 1 + precision: 16-mixed + logger: + class_path: TensorBoardLogger + init_args: + save_dir: + name: replicate + callbacks: + - class_path: RichProgressBar + - class_path: LearningRateMonitor + init_args: + logging_interval: epoch + + max_epochs: 200 + check_val_every_n_epoch: 1 + log_every_n_steps: 50 + enable_checkpointing: true + default_root_dir: + +# data available at: https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification +data: + class_path: GenericNonGeoSegmentationDataModule + init_args: + batch_size: 8 + num_workers: 12 + train_transform: + - class_path: FlattenTemporalIntoChannels + - class_path: albumentations.Flip + - class_path: ToTensorV2 + - class_path: UnflattenTemporalFromChannels + init_args: + n_timesteps: 3 + + dataset_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + output_bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + rgb_indices: + - 2 + - 1 + - 0 + reduce_zero_label: True + expand_temporal_dimension: True + train_data_root: /training_chips + train_label_data_root: /training_chips + val_data_root: /validation_chips + val_label_data_root: /validation_chips + test_data_root: /validation_chips + test_label_data_root: /validation_chips + train_split: /training_chips/training_data.txt + test_split: /validation_chips/validation_data.txt + val_split: /validation_chips/validation_data.txt + img_grep: "*_merged.tif" + label_grep: "*.mask.tif" + means: + - 494.905781 + - 815.239594 + - 924.335066 + - 2968.881459 + - 2634.621962 + - 1739.579917 + stds: + - 284.925432 + - 357.84876 + - 575.566823 + - 896.601013 + - 951.900334 + - 921.407808 + num_classes: 13 + +model: + class_path: terratorch.tasks.SemanticSegmentationTask + init_args: + model_args: + decoder: FCNDecoder + pretrained: true + backbone: prithvi_vit_100 + in_channels: 6 + rescale: False + bands: + - BLUE + - GREEN + - RED + - NIR_NARROW + - SWIR_1 + - SWIR_2 + num_frames: 3 + num_classes: 13 + head_dropout: 0.1 + decoder_channels: 512 + head_channel_list: + - 128 + - 64 + loss: ce + class_names: + - Natural Vegetation + - Forest + - Corn + - Soybeans + - Wetlands + - Developed/Barren + - Open Water + - Winter Wheat + - Alfalfa + - Fallow/Idle Cropland + - Cotton + - Sorghum + - Other + # aux_heads: + # - name: aux_head + # decoder: FCNDecoder + # decoder_args: + # decoder_channels: 256 + # decoder_in_index: 2 + # decoder_num_convs: 2 + # head_channel_list: + # - 64 + # aux_loss: + # aux_head: 1.0 + class_weights: + - 0.386375 + - 0.661126 + - 0.548184 + - 0.640482 + - 0.876862 + - 0.925186 + - 3.249462 + - 1.542289 + - 2.175141 + - 2.272419 + - 3.062762 + - 3.626097 + - 1.198702 + + ignore_index: -1 + freeze_backbone: false + freeze_decoder: false + model_factory: PrithviModelFactory + tiled_inference_parameters: + h_crop: 224 + h_stride: 196 + w_crop: 224 + w_stride: 196 + average_patches: true +optimizer: + class_path: torch.optim.AdamW + init_args: + lr: 1.5e-5 + weight_decay: 0.05 +lr_scheduler: + class_path: ReduceLROnPlateau + init_args: + monitor: val/loss