Skip to content

Commit

Permalink
add burn scars and multi temporal crop examples
Browse files Browse the repository at this point in the history
Signed-off-by: Carlos Gomes <[email protected]>
  • Loading branch information
CarlosGomes98 committed May 30, 2024
1 parent c4afa03 commit 6fc5e62
Show file tree
Hide file tree
Showing 2 changed files with 300 additions and 0 deletions.
131 changes: 131 additions & 0 deletions examples/confs/burn_scars.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: <path>
name: fire_scars
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: EarlyStopping
init_args:
monitor: val/loss
patience: 40

max_epochs: 200
check_val_every_n_epoch: 1
log_every_n_steps: 50
enable_checkpointing: true
default_root_dir: <path>

# dataset available: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars
data:
class_path: GenericNonGeoSegmentationDataModule
init_args:
batch_size: 4
num_workers: 8
dataset_bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
output_bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
rgb_indices:
- 0
- 1
- 2
train_transform:
- class_path: albumentations.RandomCrop
init_args:
height: 224
width: 224
- class_path: albumentations.HorizontalFlip
init_args:
p: 0.5
- class_path: ToTensorV2
no_data_replace: 0
no_label_replace: -1
train_data_root: <data_path>/training
train_label_data_root: <data_path>/training
val_data_root: <data_path>/validation
val_label_data_root: <data_path>/validation
test_data_root: <data_path>/validation
test_label_data_root: <data_path>/validation
img_grep: "*_merged.tif"
label_grep: "*.mask.tif"
means:
- 0.033349706741586264
- 0.05701185520536176
- 0.05889748132001316
- 0.2323245113436119
- 0.1972854853760658
- 0.11944914225186566
stds:
- 0.02269135568823774
- 0.026807560223070237
- 0.04004109844362779
- 0.07791732423672691
- 0.08708738838140137
- 0.07241979477437814
num_classes: 2

model:
class_path: terratorch.tasks.SemanticSegmentationTask
init_args:
model_args:
decoder: FCNDecoder
pretrained: true
backbone: prithvi_vit_100
decoder_channels: 256
in_channels: 6
bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
num_classes: 2
head_dropout: 0.1
decoder_num_convs: 4
head_channel_list:
- 256
loss: dice
plot_on_val: 10
ignore_index: -1
freeze_backbone: false
freeze_decoder: false
model_factory: PrithviModelFactory
tiled_inference_parameters:
h_crop: 512
h_stride: 496
w_crop: 512
w_stride: 496
average_patches: true
optimizer:
class_path: torch.optim.Adam
init_args:
lr: 1.5e-5
weight_decay: 0.05
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss
169 changes: 169 additions & 0 deletions examples/confs/multi_temporal_crop.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: <path>
name: replicate
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch

max_epochs: 200
check_val_every_n_epoch: 1
log_every_n_steps: 50
enable_checkpointing: true
default_root_dir: <path>

# data available at: https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification
data:
class_path: GenericNonGeoSegmentationDataModule
init_args:
batch_size: 8
num_workers: 12
train_transform:
- class_path: FlattenTemporalIntoChannels
- class_path: albumentations.Flip
- class_path: ToTensorV2
- class_path: UnflattenTemporalFromChannels
init_args:
n_timesteps: 3

dataset_bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
output_bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
rgb_indices:
- 2
- 1
- 0
reduce_zero_label: True
expand_temporal_dimension: True
train_data_root: <data_path>/training_chips
train_label_data_root: <data_path>/training_chips
val_data_root: <data_path>/validation_chips
val_label_data_root: <data_path>/validation_chips
test_data_root: <data_path>/validation_chips
test_label_data_root: <data_path>/validation_chips
train_split: <data_path>/training_chips/training_data.txt
test_split: <data_path>/validation_chips/validation_data.txt
val_split: <data_path>/validation_chips/validation_data.txt
img_grep: "*_merged.tif"
label_grep: "*.mask.tif"
means:
- 494.905781
- 815.239594
- 924.335066
- 2968.881459
- 2634.621962
- 1739.579917
stds:
- 284.925432
- 357.84876
- 575.566823
- 896.601013
- 951.900334
- 921.407808
num_classes: 13

model:
class_path: terratorch.tasks.SemanticSegmentationTask
init_args:
model_args:
decoder: FCNDecoder
pretrained: true
backbone: prithvi_vit_100
in_channels: 6
rescale: False
bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 3
num_classes: 13
head_dropout: 0.1
decoder_channels: 512
head_channel_list:
- 128
- 64
loss: ce
class_names:
- Natural Vegetation
- Forest
- Corn
- Soybeans
- Wetlands
- Developed/Barren
- Open Water
- Winter Wheat
- Alfalfa
- Fallow/Idle Cropland
- Cotton
- Sorghum
- Other
# aux_heads:
# - name: aux_head
# decoder: FCNDecoder
# decoder_args:
# decoder_channels: 256
# decoder_in_index: 2
# decoder_num_convs: 2
# head_channel_list:
# - 64
# aux_loss:
# aux_head: 1.0
class_weights:
- 0.386375
- 0.661126
- 0.548184
- 0.640482
- 0.876862
- 0.925186
- 3.249462
- 1.542289
- 2.175141
- 2.272419
- 3.062762
- 3.626097
- 1.198702

ignore_index: -1
freeze_backbone: false
freeze_decoder: false
model_factory: PrithviModelFactory
tiled_inference_parameters:
h_crop: 224
h_stride: 196
w_crop: 224
w_stride: 196
average_patches: true
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 1.5e-5
weight_decay: 0.05
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss

0 comments on commit 6fc5e62

Please sign in to comment.