Skip to content

Commit

Permalink
add solar farm config
Browse files Browse the repository at this point in the history
  • Loading branch information
uakfdotb committed Jan 31, 2025
1 parent 3c9083c commit 1efd544
Showing 1 changed file with 207 additions and 0 deletions.
207 changes: 207 additions & 0 deletions data/satlas/solar_farm/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
model:
class_path: rslearn.train.lightning_module.RslearnLightningModule
init_args:
model:
class_path: rslearn.models.multitask.MultiTaskModel
init_args:
encoder:
- class_path: rslearn.models.simple_time_series.SimpleTimeSeries
init_args:
encoder:
class_path: rslearn.models.swin.Swin
init_args:
pretrained: true
input_channels: 9
output_layers: [1, 3, 5, 7]
image_channels: 9
decoders:
segment:
- class_path: rslearn.models.unet.UNetDecoder
init_args:
in_channels: [[4, 128], [8, 256], [16, 512], [32, 1024]]
out_channels: 2
conv_layers_per_resolution: 2
- class_path: rslearn.train.tasks.segmentation.SegmentationHead
lr: 0.00002
plateau: true
plateau_factor: 0.2
plateau_patience: 2
plateau_min_lr: 0
plateau_cooldown: 10
restore_config:
restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth
remap_prefixes:
- ["backbone.backbone.backbone.", "encoder.0.encoder.model."]
data:
class_path: rslearn.train.data_module.RslearnDataModule
init_args:
path: gs://rslearn-eai/datasets/solar_farm/dataset_v1/20250108/
inputs:
image1:
data_type: "raster"
layers: ["sentinel2"]
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
passthrough: true
dtype: FLOAT32
image2:
data_type: "raster"
layers: ["sentinel2.1"]
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
passthrough: true
dtype: FLOAT32
image3:
data_type: "raster"
layers: ["sentinel2.2"]
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
passthrough: true
dtype: FLOAT32
image4:
data_type: "raster"
layers: ["sentinel2.3"]
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
passthrough: true
dtype: FLOAT32
mask:
data_type: "raster"
layers: ["mask"]
bands: ["mask"]
passthrough: true
dtype: INT32
is_target: true
targets:
data_type: "raster"
layers: ["label_raster"]
bands: ["label"]
dtype: INT32
is_target: true
task:
class_path: rslearn.train.tasks.multi_task.MultiTask
init_args:
tasks:
segment:
class_path: rslearn.train.tasks.segmentation.SegmentationTask
init_args:
num_classes: 2
metric_kwargs:
average: "micro"
remap_values: [[0, 1], [0, 255]]
input_mapping:
segment:
targets: "targets"
batch_size: 8
num_workers: 32
default_config:
transforms:
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 3000
valid_range: [0, 1]
bands: [0, 1, 2]
selectors: ["image1", "image2", "image3", "image4"]
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 8160
valid_range: [0, 1]
bands: [3, 4, 5, 6, 7, 8]
selectors: ["image1", "image2", "image3", "image4"]
- class_path: rslearn.train.transforms.concatenate.Concatenate
init_args:
selections:
image1: []
image2: []
image3: []
image4: []
output_selector: image
- class_path: rslp.transforms.mask.Mask
train_config:
patch_size: 256
transforms:
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 3000
valid_range: [0, 1]
bands: [0, 1, 2]
selectors: ["image1", "image2", "image3", "image4"]
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 8160
valid_range: [0, 1]
bands: [3, 4, 5, 6, 7, 8]
selectors: ["image1", "image2", "image3", "image4"]
- class_path: rslearn.train.transforms.concatenate.Concatenate
init_args:
selections:
image1: []
image2: []
image3: []
image4: []
output_selector: image
- class_path: rslp.transforms.mask.Mask
- class_path: rslearn.train.transforms.flip.Flip
init_args:
image_selectors: ["image", "target/segment/classes", "target/segment/valid"]
tags:
split: train
val_config:
patch_size: 256
tags:
split: val
test_config:
patch_size: 256
tags:
split: val
predict_config:
groups: ["predict"]
load_all_patches: true
skip_targets: true
patch_size: 512
transforms:
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 3000
valid_range: [0, 1]
bands: [0, 1, 2]
selectors: ["image1", "image2", "image3", "image4"]
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 8160
valid_range: [0, 1]
bands: [3, 4, 5, 6, 7, 8]
selectors: ["image1", "image2", "image3", "image4"]
- class_path: rslearn.train.transforms.concatenate.Concatenate
init_args:
selections:
image1: []
image2: []
image3: []
image4: []
output_selector: image
trainer:
max_epochs: 500
callbacks:
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: "epoch"
- class_path: rslearn.train.prediction_writer.RslearnWriter
init_args:
path: gs://rslearn-eai/datasets/solar_farm/dataset_v1/20250108/
output_layer: output
selector: ["detect"]
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
save_top_k: 1
save_last: true
monitor: val_segment/accuracy
mode: max
- class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze
init_args:
module_selector: ["model", "encoder", 0, "encoder", "model"]
unfreeze_at_epoch: 2
rslp_project: satlas_solar_farm
rslp_experiment: data_20250108_satlaspretrain_patch256_00

0 comments on commit 1efd544

Please sign in to comment.