-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
207 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,207 @@ | ||
model: | ||
class_path: rslearn.train.lightning_module.RslearnLightningModule | ||
init_args: | ||
model: | ||
class_path: rslearn.models.multitask.MultiTaskModel | ||
init_args: | ||
encoder: | ||
- class_path: rslearn.models.simple_time_series.SimpleTimeSeries | ||
init_args: | ||
encoder: | ||
class_path: rslearn.models.swin.Swin | ||
init_args: | ||
pretrained: true | ||
input_channels: 9 | ||
output_layers: [1, 3, 5, 7] | ||
image_channels: 9 | ||
decoders: | ||
segment: | ||
- class_path: rslearn.models.unet.UNetDecoder | ||
init_args: | ||
in_channels: [[4, 128], [8, 256], [16, 512], [32, 1024]] | ||
out_channels: 2 | ||
conv_layers_per_resolution: 2 | ||
- class_path: rslearn.train.tasks.segmentation.SegmentationHead | ||
lr: 0.00002 | ||
plateau: true | ||
plateau_factor: 0.2 | ||
plateau_patience: 2 | ||
plateau_min_lr: 0 | ||
plateau_cooldown: 10 | ||
restore_config: | ||
restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth | ||
remap_prefixes: | ||
- ["backbone.backbone.backbone.", "encoder.0.encoder.model."] | ||
data: | ||
class_path: rslearn.train.data_module.RslearnDataModule | ||
init_args: | ||
path: gs://rslearn-eai/datasets/solar_farm/dataset_v1/20250108/ | ||
inputs: | ||
image1: | ||
data_type: "raster" | ||
layers: ["sentinel2"] | ||
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] | ||
passthrough: true | ||
dtype: FLOAT32 | ||
image2: | ||
data_type: "raster" | ||
layers: ["sentinel2.1"] | ||
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] | ||
passthrough: true | ||
dtype: FLOAT32 | ||
image3: | ||
data_type: "raster" | ||
layers: ["sentinel2.2"] | ||
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] | ||
passthrough: true | ||
dtype: FLOAT32 | ||
image4: | ||
data_type: "raster" | ||
layers: ["sentinel2.3"] | ||
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"] | ||
passthrough: true | ||
dtype: FLOAT32 | ||
mask: | ||
data_type: "raster" | ||
layers: ["mask"] | ||
bands: ["mask"] | ||
passthrough: true | ||
dtype: INT32 | ||
is_target: true | ||
targets: | ||
data_type: "raster" | ||
layers: ["label_raster"] | ||
bands: ["label"] | ||
dtype: INT32 | ||
is_target: true | ||
task: | ||
class_path: rslearn.train.tasks.multi_task.MultiTask | ||
init_args: | ||
tasks: | ||
segment: | ||
class_path: rslearn.train.tasks.segmentation.SegmentationTask | ||
init_args: | ||
num_classes: 2 | ||
metric_kwargs: | ||
average: "micro" | ||
remap_values: [[0, 1], [0, 255]] | ||
input_mapping: | ||
segment: | ||
targets: "targets" | ||
batch_size: 8 | ||
num_workers: 32 | ||
default_config: | ||
transforms: | ||
- class_path: rslearn.train.transforms.normalize.Normalize | ||
init_args: | ||
mean: 0 | ||
std: 3000 | ||
valid_range: [0, 1] | ||
bands: [0, 1, 2] | ||
selectors: ["image1", "image2", "image3", "image4"] | ||
- class_path: rslearn.train.transforms.normalize.Normalize | ||
init_args: | ||
mean: 0 | ||
std: 8160 | ||
valid_range: [0, 1] | ||
bands: [3, 4, 5, 6, 7, 8] | ||
selectors: ["image1", "image2", "image3", "image4"] | ||
- class_path: rslearn.train.transforms.concatenate.Concatenate | ||
init_args: | ||
selections: | ||
image1: [] | ||
image2: [] | ||
image3: [] | ||
image4: [] | ||
output_selector: image | ||
- class_path: rslp.transforms.mask.Mask | ||
train_config: | ||
patch_size: 256 | ||
transforms: | ||
- class_path: rslearn.train.transforms.normalize.Normalize | ||
init_args: | ||
mean: 0 | ||
std: 3000 | ||
valid_range: [0, 1] | ||
bands: [0, 1, 2] | ||
selectors: ["image1", "image2", "image3", "image4"] | ||
- class_path: rslearn.train.transforms.normalize.Normalize | ||
init_args: | ||
mean: 0 | ||
std: 8160 | ||
valid_range: [0, 1] | ||
bands: [3, 4, 5, 6, 7, 8] | ||
selectors: ["image1", "image2", "image3", "image4"] | ||
- class_path: rslearn.train.transforms.concatenate.Concatenate | ||
init_args: | ||
selections: | ||
image1: [] | ||
image2: [] | ||
image3: [] | ||
image4: [] | ||
output_selector: image | ||
- class_path: rslp.transforms.mask.Mask | ||
- class_path: rslearn.train.transforms.flip.Flip | ||
init_args: | ||
image_selectors: ["image", "target/segment/classes", "target/segment/valid"] | ||
tags: | ||
split: train | ||
val_config: | ||
patch_size: 256 | ||
tags: | ||
split: val | ||
test_config: | ||
patch_size: 256 | ||
tags: | ||
split: val | ||
predict_config: | ||
groups: ["predict"] | ||
load_all_patches: true | ||
skip_targets: true | ||
patch_size: 512 | ||
transforms: | ||
- class_path: rslearn.train.transforms.normalize.Normalize | ||
init_args: | ||
mean: 0 | ||
std: 3000 | ||
valid_range: [0, 1] | ||
bands: [0, 1, 2] | ||
selectors: ["image1", "image2", "image3", "image4"] | ||
- class_path: rslearn.train.transforms.normalize.Normalize | ||
init_args: | ||
mean: 0 | ||
std: 8160 | ||
valid_range: [0, 1] | ||
bands: [3, 4, 5, 6, 7, 8] | ||
selectors: ["image1", "image2", "image3", "image4"] | ||
- class_path: rslearn.train.transforms.concatenate.Concatenate | ||
init_args: | ||
selections: | ||
image1: [] | ||
image2: [] | ||
image3: [] | ||
image4: [] | ||
output_selector: image | ||
trainer: | ||
max_epochs: 500 | ||
callbacks: | ||
- class_path: lightning.pytorch.callbacks.LearningRateMonitor | ||
init_args: | ||
logging_interval: "epoch" | ||
- class_path: rslearn.train.prediction_writer.RslearnWriter | ||
init_args: | ||
path: gs://rslearn-eai/datasets/solar_farm/dataset_v1/20250108/ | ||
output_layer: output | ||
selector: ["detect"] | ||
- class_path: lightning.pytorch.callbacks.ModelCheckpoint | ||
init_args: | ||
save_top_k: 1 | ||
save_last: true | ||
monitor: val_segment/accuracy | ||
mode: max | ||
- class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze | ||
init_args: | ||
module_selector: ["model", "encoder", 0, "encoder", "model"] | ||
unfreeze_at_epoch: 2 | ||
rslp_project: satlas_solar_farm | ||
rslp_experiment: data_20250108_satlaspretrain_patch256_00 |