Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/add more examples #14

Merged
merged 2 commits into from
May 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions docs/examples.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@ For some examples of training using the existing tasks, check out the following

Under `examples/confs`

* Flood Segmentation with ViT: `segmentation_config_vit.yaml`
* Flood Segmentation with ViT: `sen1floods11_vit.yaml`

* Multitemporal Crop Segmentation: `multitemporal_crop.yaml`

* Scene Classification: `eurosat.yaml`
* Burn Scar Segmentation: `burn_scars.yaml`

* Usage of an SMP backbone `geobench/segmentation/m_chesapeake_landcover_smp_resnet_unet.yaml`

* Usage of a timm backbone `geobench/classification/m_bigearthnet_timm_resnet.yaml`
* Scene Classification: `eurosat.yaml`
131 changes: 131 additions & 0 deletions examples/confs/burn_scars.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: <path>
name: fire_scars
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch
- class_path: EarlyStopping
init_args:
monitor: val/loss
patience: 40

max_epochs: 200
check_val_every_n_epoch: 1
log_every_n_steps: 50
enable_checkpointing: true
default_root_dir: <path>

# dataset available: https://huggingface.co/datasets/ibm-nasa-geospatial/hls_burn_scars
data:
class_path: GenericNonGeoSegmentationDataModule
init_args:
batch_size: 4
num_workers: 8
dataset_bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
output_bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
rgb_indices:
- 0
- 1
- 2
train_transform:
- class_path: albumentations.RandomCrop
init_args:
height: 224
width: 224
- class_path: albumentations.HorizontalFlip
init_args:
p: 0.5
- class_path: ToTensorV2
no_data_replace: 0
no_label_replace: -1
train_data_root: <data_path>/training
train_label_data_root: <data_path>/training
val_data_root: <data_path>/validation
val_label_data_root: <data_path>/validation
test_data_root: <data_path>/validation
test_label_data_root: <data_path>/validation
img_grep: "*_merged.tif"
label_grep: "*.mask.tif"
means:
- 0.033349706741586264
- 0.05701185520536176
- 0.05889748132001316
- 0.2323245113436119
- 0.1972854853760658
- 0.11944914225186566
stds:
- 0.02269135568823774
- 0.026807560223070237
- 0.04004109844362779
- 0.07791732423672691
- 0.08708738838140137
- 0.07241979477437814
num_classes: 2

model:
class_path: terratorch.tasks.SemanticSegmentationTask
init_args:
model_args:
decoder: FCNDecoder
pretrained: true
backbone: prithvi_vit_100
decoder_channels: 256
in_channels: 6
bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 1
num_classes: 2
head_dropout: 0.1
decoder_num_convs: 4
head_channel_list:
- 256
loss: dice
plot_on_val: 10
ignore_index: -1
freeze_backbone: false
freeze_decoder: false
model_factory: PrithviModelFactory
tiled_inference_parameters:
h_crop: 512
h_stride: 496
w_crop: 512
w_stride: 496
average_patches: true
optimizer:
class_path: torch.optim.Adam
init_args:
lr: 1.5e-5
weight_decay: 0.05
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss
169 changes: 169 additions & 0 deletions examples/confs/multi_temporal_crop.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
# lightning.pytorch==2.1.1
seed_everything: 0
trainer:
accelerator: auto
strategy: auto
devices: auto
num_nodes: 1
precision: 16-mixed
logger:
class_path: TensorBoardLogger
init_args:
save_dir: <path>
name: replicate
callbacks:
- class_path: RichProgressBar
- class_path: LearningRateMonitor
init_args:
logging_interval: epoch

max_epochs: 200
check_val_every_n_epoch: 1
log_every_n_steps: 50
enable_checkpointing: true
default_root_dir: <path>

# data available at: https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification
data:
class_path: GenericNonGeoSegmentationDataModule
init_args:
batch_size: 8
num_workers: 12
train_transform:
- class_path: FlattenTemporalIntoChannels
- class_path: albumentations.Flip
- class_path: ToTensorV2
- class_path: UnflattenTemporalFromChannels
init_args:
n_timesteps: 3

dataset_bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
output_bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
rgb_indices:
- 2
- 1
- 0
reduce_zero_label: True
expand_temporal_dimension: True
train_data_root: <data_path>/training_chips
train_label_data_root: <data_path>/training_chips
val_data_root: <data_path>/validation_chips
val_label_data_root: <data_path>/validation_chips
test_data_root: <data_path>/validation_chips
test_label_data_root: <data_path>/validation_chips
train_split: <data_path>/training_chips/training_data.txt
test_split: <data_path>/validation_chips/validation_data.txt
val_split: <data_path>/validation_chips/validation_data.txt
img_grep: "*_merged.tif"
label_grep: "*.mask.tif"
means:
- 494.905781
- 815.239594
- 924.335066
- 2968.881459
- 2634.621962
- 1739.579917
stds:
- 284.925432
- 357.84876
- 575.566823
- 896.601013
- 951.900334
- 921.407808
num_classes: 13

model:
class_path: terratorch.tasks.SemanticSegmentationTask
init_args:
model_args:
decoder: FCNDecoder
pretrained: true
backbone: prithvi_vit_100
in_channels: 6
rescale: False
bands:
- BLUE
- GREEN
- RED
- NIR_NARROW
- SWIR_1
- SWIR_2
num_frames: 3
num_classes: 13
head_dropout: 0.1
decoder_channels: 512
head_channel_list:
- 128
- 64
loss: ce
class_names:
- Natural Vegetation
- Forest
- Corn
- Soybeans
- Wetlands
- Developed/Barren
- Open Water
- Winter Wheat
- Alfalfa
- Fallow/Idle Cropland
- Cotton
- Sorghum
- Other
# aux_heads:
# - name: aux_head
# decoder: FCNDecoder
# decoder_args:
# decoder_channels: 256
# decoder_in_index: 2
# decoder_num_convs: 2
# head_channel_list:
# - 64
# aux_loss:
# aux_head: 1.0
class_weights:
- 0.386375
- 0.661126
- 0.548184
- 0.640482
- 0.876862
- 0.925186
- 3.249462
- 1.542289
- 2.175141
- 2.272419
- 3.062762
- 3.626097
- 1.198702

ignore_index: -1
freeze_backbone: false
freeze_decoder: false
model_factory: PrithviModelFactory
tiled_inference_parameters:
h_crop: 224
h_stride: 196
w_crop: 224
w_stride: 196
average_patches: true
optimizer:
class_path: torch.optim.AdamW
init_args:
lr: 1.5e-5
weight_decay: 0.05
lr_scheduler:
class_path: ReduceLROnPlateau
init_args:
monitor: val/loss
Loading