Skip to content

Commit

Permalink
sync
Browse files Browse the repository at this point in the history
  • Loading branch information
favyen2 committed Dec 12, 2024
1 parent 07565da commit 7085a2f
Show file tree
Hide file tree
Showing 19 changed files with 2,115 additions and 226 deletions.
36 changes: 33 additions & 3 deletions convert_satlas_webmercator_to_rslearn/wind_turbine/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"output": {
"type": "vector"
},
"sentinel2": {
"sentinel2_a": {
"band_sets": [
{
"bands": [
Expand Down Expand Up @@ -68,7 +68,8 @@
},
"type": "raster"
},
"sentinel2.1": {
"sentinel2_b": {
"alias": "sentinel2",
"band_sets": [
{
"bands": [
Expand Down Expand Up @@ -101,9 +102,24 @@
"zoom_offset": -2
}
],
"data_source": {
"harmonize": true,
"index_cache_dir": "cache/sentinel2",
"max_time_delta": "1d",
"modality": "L1C",
"name": "rslearn.data_sources.gcp_public_data.Sentinel2",
"query_config": {
"max_matches": 2,
"space_mode": "CONTAINS"
},
"sort_by": "cloud_cover",
"time_offset": "-90d",
"use_rtree_index": false
},
"type": "raster"
},
"sentinel2.2": {
"sentinel2_c": {
"alias": "sentinel2",
"band_sets": [
{
"bands": [
Expand Down Expand Up @@ -136,6 +152,20 @@
"zoom_offset": -2
}
],
"data_source": {
"harmonize": true,
"index_cache_dir": "cache/sentinel2",
"max_time_delta": "1d",
"modality": "L1C",
"name": "rslearn.data_sources.gcp_public_data.Sentinel2",
"query_config": {
"max_matches": 2,
"space_mode": "CONTAINS"
},
"sort_by": "cloud_cover",
"time_offset": "-180d",
"use_rtree_index": false
},
"type": "raster"
}
},
Expand Down
2 changes: 1 addition & 1 deletion data/satlas/marine_infra/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -216,4 +216,4 @@ trainer:
module_selector: ["model", "encoder", 0, "encoder", "model"]
unfreeze_at_epoch: 2
rslp_project: satlas_marine_infra
rslp_experiment: data_20241030_satlaspretrainold_patch512_00
rslp_experiment: data_20241030_run_20241210_00
210 changes: 210 additions & 0 deletions data/satlas/marine_infra/config_20241002.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
model:
class_path: rslearn.train.lightning_module.RslearnLightningModule
init_args:
model:
class_path: rslearn.models.multitask.MultiTaskModel
init_args:
encoder:
- class_path: rslearn.models.simple_time_series.SimpleTimeSeries
init_args:
encoder:
class_path: rslearn.models.swin.Swin
init_args:
pretrained: true
input_channels: 9
output_layers: [1, 3, 5, 7]
image_channels: 9
- class_path: rslearn.models.fpn.Fpn
init_args:
in_channels: [128, 256, 512, 1024]
out_channels: 128
decoders:
detect:
- class_path: rslearn.models.faster_rcnn.FasterRCNN
init_args:
downsample_factors: [4, 8, 16, 32]
num_channels: 128
num_classes: 3
anchor_sizes: [[32], [64], [128], [256]]
lr: 0.00002
plateau: true
plateau_factor: 0.2
plateau_patience: 2
plateau_min_lr: 0
plateau_cooldown: 10
restore_config:
restore_path: https://ai2-public-datasets.s3.amazonaws.com/satlas/satlas-model-v1-lowres-band-multi.pth
remap_prefixes:
- ["backbone.backbone.backbone.", "encoder.0.encoder.model."]
data:
class_path: rslearn.train.data_module.RslearnDataModule
init_args:
path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/live/
inputs:
image1:
data_type: "raster"
layers: ["sentinel2"]
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
passthrough: true
dtype: FLOAT32
image2:
data_type: "raster"
layers: ["sentinel2.1"]
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
passthrough: true
dtype: FLOAT32
image3:
data_type: "raster"
layers: ["sentinel2.2"]
bands: ["B04", "B03", "B02", "B05", "B06", "B07", "B08", "B11", "B12"]
passthrough: true
dtype: FLOAT32
mask:
data_type: "raster"
layers: ["mask"]
bands: ["mask"]
passthrough: true
dtype: FLOAT32
is_target: true
targets:
data_type: "vector"
layers: ["label"]
is_target: true
task:
class_path: rslearn.train.tasks.multi_task.MultiTask
init_args:
tasks:
detect:
class_path: rslp.satlas.train.MarineInfraTask
init_args:
property_name: "category"
classes: ["unknown", "platform", "turbine"]
box_size: 15
remap_values: [[0, 1], [0, 255]]
exclude_by_center: true
enable_map_metric: true
enable_f1_metric: true
f1_metric_thresholds: [[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.95], [0.1], [0.2], [0.3], [0.4], [0.5], [0.6], [0.7], [0.8], [0.9]]
skip_unknown_categories: true
f1_metric_kwargs:
cmp_mode: "distance"
cmp_threshold: 15
flatten_classes: true
input_mapping:
detect:
targets: "targets"
batch_size: 8
num_workers: 32
default_config:
transforms:
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 3000
valid_range: [0, 1]
bands: [0, 1, 2]
selectors: ["image1", "image2", "image3"]
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 8160
valid_range: [0, 1]
bands: [3, 4, 5, 6, 7, 8]
selectors: ["image1", "image2", "image3"]
- class_path: rslearn.train.transforms.concatenate.Concatenate
init_args:
selections:
image1: []
image2: []
image3: []
output_selector: image
- class_path: rslp.transforms.mask.Mask
train_config:
patch_size: 512
transforms:
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 3000
valid_range: [0, 1]
bands: [0, 1, 2]
selectors: ["image1", "image2", "image3"]
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 8160
valid_range: [0, 1]
bands: [3, 4, 5, 6, 7, 8]
selectors: ["image1", "image2", "image3"]
- class_path: rslearn.train.transforms.concatenate.Concatenate
init_args:
selections:
image1: []
image2: []
image3: []
output_selector: image
- class_path: rslp.transforms.mask.Mask
- class_path: rslearn.train.transforms.flip.Flip
init_args:
image_selectors: ["image"]
box_selectors: ["target/detect"]
tags:
split: train
val_config:
patch_size: 512
tags:
split: val
test_config:
patch_size: 512
tags:
split: val
predict_config:
transforms:
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 3000
valid_range: [0, 1]
bands: [0, 1, 2]
selectors: ["image1", "image2", "image3"]
- class_path: rslearn.train.transforms.normalize.Normalize
init_args:
mean: 0
std: 8160
valid_range: [0, 1]
bands: [3, 4, 5, 6, 7, 8]
selectors: ["image1", "image2", "image3"]
- class_path: rslearn.train.transforms.concatenate.Concatenate
init_args:
selections:
image1: []
image2: []
image3: []
output_selector: image
groups: ["predict"]
load_all_patches: true
skip_targets: true
patch_size: 512
trainer:
max_epochs: 500
callbacks:
- class_path: lightning.pytorch.callbacks.LearningRateMonitor
init_args:
logging_interval: "epoch"
- class_path: rslearn.train.prediction_writer.RslearnWriter
init_args:
path: gs://rslearn-eai/datasets/marine_infra/dataset_v1/live/
output_layer: output
selector: ["detect"]
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
init_args:
save_top_k: 1
save_last: true
monitor: val_detect/mAP
mode: max
- class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze
init_args:
module_selector: ["model", "encoder", 0, "encoder", "model"]
unfreeze_at_epoch: 2
rslp_project: satlas_marine_infra
rslp_experiment: data_20241002_run_20241210_00
Loading

0 comments on commit 7085a2f

Please sign in to comment.