diff --git a/README.md b/README.md
index 21bc556..c9a8232 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,9 @@
# Open Driving World Models (OpenDWM)
-[[中文简介](README_intro_zh.md)]
+[Video link](https://youtu.be/j9RRj-xzOA4) [[中文简介](README_intro_zh.md)]
https://github.com/user-attachments/assets/649d3b81-3b1f-44f9-9f51-4d1ed7756476
-[Video link](https://youtu.be/j9RRj-xzOA4)
-
Welcome to the OpenDWM project! This is an open-source initiative, focusing on autonomous driving video generation. Our mission is to provide a high-quality, controllable tool for generating autonomous driving videos using the latest technology. We aim to build a codebase that is both user-friendly and highly reusable, and hope to continuously improve the project through the collective wisdom of the community.
The driving world models generate multi-view images or videos of autonomous driving scenes based on text and road environment layout conditions. Whether it's the environment, weather conditions, vehicle type, or driving path, you can adjust them according to your needs.
@@ -49,6 +47,7 @@ python -m pip install torch==2.5.1 torchvision==0.20.1
```
Clone the repository, then install the dependencies.
+
```
cd DWM
git submodule update --init --recursive
@@ -62,8 +61,8 @@ Our cross-view temporal SD (CTSD) pipeline support loading the pretrained SD 2.1
| Base model | Text conditioned<br/>driving generation | Text and layout (box, map)<br/>conditioned driving generation |
| :-: | :-: | :-: |
| [SD 2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) | [Config](configs/ctsd/multi_datasets/ctsd_21_tirda_nwao.json), [Download](http://103.237.29.236:10030/ctsd_21_tirda_nwao_30k.pth) | [Config](configs/ctsd/multi_datasets/ctsd_21_tirda_bm_nwa.json), [Download](http://103.237.29.236:10030/ctsd_21_tirda_bm_nwa_30k.pth) |
-| [SD 3.0](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) | | [UniMLVG Config](configs/ctsd/unimlvg/unimlvg_stage3_tirda_nwa.json), [Download](http://103.237.29.236:10030/ctsd_unimlvg_tirda_bm_nwa_60k.pth) |
-| [SD 3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | [Config](configs/ctsd/multi_datasets/ctsd_35_tirda_nwao.json), [Download](http://103.237.29.236:10030/ctsd_35_tirda_nwao_20k.pth) | [Config](configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa.json), Released by 2025-3-1 |
+| [SD 3.0](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) | | [UniMLVG Config](configs/ctsd/unimlvg/ctsd_unimlvg_stage3_tirda_bm_nwa.json), [Download](http://103.237.29.236:10030/ctsd_unimlvg_tirda_bm_nwa_60k.pth) |
+| [SD 3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | [Config](configs/ctsd/multi_datasets/ctsd_35_tirda_nwao.json), [Download](http://103.237.29.236:10030/ctsd_35_tirda_nwao_20k.pth) | [Config](configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwao.json), [Download](http://103.237.29.236:10030/ctsd_35_tirda_bm_nwao_40k.pth) |
## Examples
@@ -71,18 +70,18 @@ Our cross-view temporal SD (CTSD) pipeline support loading the pretrained SD 2.1
Download base model (for VAE, text encoders, scheduler config) and driving generation model checkpoint, and edit the [path](examples/ctsd_35_6views_image_generation.json#L102) and [prompts](examples/ctsd_35_6views_image_generation.json#L221) in the JSON config, then run this command.
-```
+```bash
PYTHONPATH=src python examples/ctsd_generation_example.py -c examples/ctsd_35_6views_image_generation.json -o output/ctsd_35_6views_image_generation
```
### Layout conditioned T2V generation with CTSD pipeline
-1. Download base model (for VAE, text encoders, scheduler config) and driving generation model checkpoint, and edit the [path](examples/ctsd_21_6views_video_generation_with_layout.json#L119) in the JSON config.
-2. Download layout resource package [nuscenes_scene-0627_package.zip](http://103.237.29.236:10030/nuscenes_scene-0627_package.zip) and unzip to the `{RESOURCE_PATH}`. Then edit the meta [path](examples/ctsd_21_6views_video_generation_with_layout.json#L129) as `{RESOURCE_PATH}/data.json` in the JSON config.
+1. Download base model (for VAE, text encoders, scheduler config) and driving generation model checkpoint, and edit the [path](examples/ctsd_35_6views_video_generation_with_layout.json#L156) in the JSON config.
+2. Download layout resource package ([nuscenes_scene-0627_package.zip](http://103.237.29.236:10030/nuscenes_scene-0627_package.zip), or [carla_town04_package.zip](http://103.237.29.236:10030/carla_town04_package.zip)) and unzip it to `{RESOURCE_PATH}`. Then edit the meta [path](examples/ctsd_35_6views_video_generation_with_layout.json#L162) as `{RESOURCE_PATH}/data.json` in the JSON config.
3. Run this command to generate the video.
-```
-PYTHONPATH=src python src/dwm/preview.py -c examples/ctsd_unimlvg_6views_video_generation.json -o output/ctsd_unimlvg_6views_video_generation
+```bash
+PYTHONPATH=src python src/dwm/preview.py -c examples/ctsd_35_6views_video_generation_with_layout.json -o output/ctsd_35_6views_video_generation_with_layout
```
## Train
diff --git a/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa.json b/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa.json
deleted file mode 100644
index 982a078..0000000
--- a/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa.json
+++ /dev/null
@@ -1,828 +0,0 @@
-{
- "device": "cuda",
- "ddp_backend": "nccl",
- "train_epochs": 6,
- "generator_seed": 0,
- "data_shuffle": true,
- "fix_training_data_order": true,
- "global_state": {
- "nuscenes_fs": {
- "_class_name": "dwm.fs.czip.CombinedZipFileSystem",
- "fs": {
- "_class_name": "dwm.fs.dirfs.DirFileSystem",
- "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan"
- },
- "paths": [
- "workspaces/worldmodels/data/nuscenes/interp_12Hz_trainval.zip",
- "data/nuscenes/v1.0-trainval01_blobs.zip",
- "data/nuscenes/v1.0-trainval02_blobs.zip",
- "data/nuscenes/v1.0-trainval03_blobs.zip",
- "data/nuscenes/v1.0-trainval04_blobs.zip",
- "data/nuscenes/v1.0-trainval05_blobs.zip",
- "data/nuscenes/v1.0-trainval06_blobs.zip",
- "data/nuscenes/v1.0-trainval07_blobs.zip",
- "data/nuscenes/v1.0-trainval08_blobs.zip",
- "data/nuscenes/v1.0-trainval09_blobs.zip",
- "data/nuscenes/v1.0-trainval10_blobs.zip",
- "data/nuscenes/nuScenes-map-expansion-v1.3.zip"
- ]
- },
- "device_mesh": {
- "_class_name": "torch.distributed.device_mesh.init_device_mesh",
- "device_type": "cuda",
- "mesh_shape": [
- 4,
- 8
- ]
- }
- },
- "optimizer": {
- "_class_name": "torch.optim.AdamW",
- "lr": 6e-5,
- "betas": [
- 0.9,
- 0.975
- ]
- },
- "pipeline": {
- "_class_name": "dwm.pipelines.ctsd.CrossviewTemporalSD",
- "common_config": {
- "frame_prediction_style": "ctsd",
- "cat_condition": true,
- "cond_with_action": false,
- "condition_on_all_frames": true,
- "added_time_ids": "fps_camera_transforms",
- "camera_intrinsic_embedding_indices": [
- 0,
- 4,
- 2,
- 5
- ],
- "camera_intrinsic_denom_embedding_indices": [
- 1,
- 1,
- 0,
- 1
- ],
- "camera_transform_embedding_indices": [
- 2,
- 6,
- 10,
- 3,
- 7,
- 11
- ],
- "distribution_framework": "fsdp",
- "ddp_wrapper_settings": {
- "sharding_strategy": {
- "_class_name": "torch.distributed.fsdp.ShardingStrategy",
- "value": 4
- },
- "device_mesh": {
- "_class_name": "dwm.common.get_state",
- "key": "device_mesh"
- },
- "auto_wrap_policy": {
- "_class_name": "torch.distributed.fsdp.wrap.ModuleWrapPolicy",
- "module_classes": [
- {
- "_class_name": "get_class",
- "class_name": "diffusers.models.attention.JointTransformerBlock"
- },
- {
- "_class_name": "get_class",
- "class_name": "dwm.models.crossview_temporal.VTSelfAttentionBlock"
- }
- ]
- },
- "mixed_precision": {
- "_class_name": "torch.distributed.fsdp.MixedPrecision",
- "param_dtype": {
- "_class_name": "get_class",
- "class_name": "torch.float16"
- }
- }
- },
- "t5_fsdp_wrapper_settings": {
- "sharding_strategy": {
- "_class_name": "torch.distributed.fsdp.ShardingStrategy",
- "value": 4
- },
- "device_mesh": {
- "_class_name": "dwm.common.get_state",
- "key": "device_mesh"
- },
- "auto_wrap_policy": {
- "_class_name": "torch.distributed.fsdp.wrap.ModuleWrapPolicy",
- "module_classes": [
- {
- "_class_name": "get_class",
- "class_name": "transformers.models.t5.modeling_t5.T5Block"
- }
- ]
- }
- },
- "text_encoder_load_args": {
- "variant": "fp16",
- "torch_dtype": {
- "_class_name": "get_class",
- "class_name": "torch.float16"
- }
- },
- "memory_efficient_batch": 16
- },
- "training_config": {
- "text_prompt_condition_ratio": 0.8,
- "3dbox_condition_ratio": 0.8,
- "hdmap_condition_ratio": 0.8,
- "reference_frame_count": 3,
- "generation_task_ratio": 0.25,
- "image_generation_ratio": 0.3,
- "all_reference_visible_ratio": 1,
- "reference_frame_scale_std": 0.02,
- "reference_frame_offset_std": 0.02,
- "enable_grad_scaler": true
- },
- "inference_config": {
- "guidance_scale": 4,
- "inference_steps": 40,
- "preview_image_size": [
- 448,
- 252
- ],
- "sequence_length_per_iteration": 19,
- "reference_frame_count": 3,
- "autoregression_data_exception_for_take_sequence": [
- "crossview_mask"
- ],
- "evaluation_item_count": 480
- },
- "model": {
- "_class_name": "dwm.models.crossview_temporal_dit.DiTCrossviewTemporalConditionModel",
- "dual_attention_layers": [
- 0,
- 1,
- 2,
- 3,
- 4,
- 5,
- 6,
- 7,
- 8,
- 9,
- 10,
- 11,
- 12
- ],
- "attention_head_dim": 64,
- "caption_projection_dim": 1536,
- "in_channels": 16,
- "joint_attention_dim": 4096,
- "num_attention_heads": 24,
- "num_layers": 24,
- "out_channels": 16,
- "patch_size": 2,
- "pooled_projection_dim": 2048,
- "pos_embed_max_size": 384,
- "qk_norm": "rms_norm",
- "qk_norm_on_additional_modules": "rms_norm",
- "sample_size": 128,
- "perspective_modeling_type": "implicit",
- "projection_class_embeddings_input_dim": 2816,
- "enable_crossview": true,
- "crossview_attention_type": "rowwise",
- "crossview_block_layers": [
- 1,
- 5,
- 9,
- 13,
- 17,
- 21
- ],
- "crossview_gradient_checkpointing": true,
- "enable_temporal": true,
- "temporal_attention_type": "rowwise",
- "temporal_block_layers": [
- 2,
- 3,
- 6,
- 7,
- 10,
- 11,
- 14,
- 15,
- 18,
- 19,
- 22,
- 23
- ],
- "temporal_gradient_checkpointing": true,
- "mixer_type": "AlphaBlender",
- "merge_factor": 2,
- "condition_image_adapter_config": {
- "in_channels": 6,
- "channels": [
- 1536,
- 1536,
- 1536,
- 1536,
- 1536,
- 1536
- ],
- "is_downblocks": [
- true,
- false,
- false,
- false,
- false,
- false
- ],
- "num_res_blocks": 2,
- "downscale_factor": 8,
- "use_zero_convs": true
- }
- },
- "pretrained_model_name_or_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/models/stable-diffusion-3.5-medium",
- "model_checkpoint_path": "/mnt/afs/user/wuzehuan/Tasks/ctsd_35_tirda_bm_nwa_warmup/checkpoints/5000.pth",
- "model_load_state_args": {
- "strict": false
- },
- "metrics": {
- "fid": {
- "_class_name": "torchmetrics.image.fid.FrechetInceptionDistance",
- "normalize": true
- },
- "fvd": {
- "_class_name": "dwm.metrics.fvd.FrechetVideoDistance",
- "inception_3d_checkpoint_path": "/mnt/afs/user/wuzehuan/Documents/DWM/externals/TATS/tats/fvd/i3d_pretrained_400.pt",
- "sequence_count": 16
- }
- }
- },
- "training_dataset": {
- "_class_name": "dwm.datasets.common.DatasetAdapter",
- "base_dataset": {
- "_class_name": "torch.utils.data.ConcatDataset",
- "datasets": [
- {
- "_class_name": "dwm.datasets.nuscenes.MotionDataset",
- "fs": {
- "_class_name": "dwm.common.get_state",
- "key": "nuscenes_fs"
- },
- "dataset_name": "interp_12Hz_trainval",
- "split": "train",
- "sequence_length": 19,
- "fps_stride_tuples": [
- [
- 10,
- 0.1
- ]
- ],
- "sensor_channels": [
- "LIDAR_TOP",
- "CAM_FRONT_LEFT",
- "CAM_FRONT",
- "CAM_FRONT_RIGHT",
- "CAM_BACK_RIGHT",
- "CAM_BACK",
- "CAM_BACK_LEFT"
- ],
- "keyframe_only": true,
- "enable_camera_transforms": true,
- "enable_ego_transforms": true,
- "_3dbox_image_settings": {},
- "hdmap_image_settings": {},
- "image_description_settings": {
- "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/nuscenes_v1.0-trainval_caption_v2_train.json",
- "time_list_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/nuscenes_v1.0-trainval_caption_v2_times_train.json",
- "align_keys": [
- "time",
- "weather"
- ],
- "reorder_keys": true,
- "drop_rates": {
- "environment": 0.04,
- "objects": 0.08,
- "image_description": 0.16
- }
- },
- "stub_key_data_dict": {
- "crossview_mask": [
- "content",
- {
- "_class_name": "torch.tensor",
- "data": {
- "_class_name": "json.loads",
- "s": "[[1,1,0,0,0,1],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,1],[1,0,0,0,1,1]]"
- },
- "dtype": {
- "_class_name": "get_class",
- "class_name": "torch.bool"
- }
- }
- ]
- }
- },
- {
- "_class_name": "dwm.datasets.waymo.MotionDataset",
- "fs": {
- "_class_name": "dwm.fs.dirfs.DirFileSystem",
- "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/data/waymo/waymo_open_dataset_v_1_4_3/training"
- },
- "info_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/data/waymo/waymo_open_dataset_v_1_4_3/training.info.json",
- "sequence_length": 19,
- "fps_stride_tuples": [
- [
- 10,
- 0.1
- ]
- ],
- "sensor_channels": [
- "LIDAR_TOP",
- "CAM_SIDE_LEFT",
- "CAM_FRONT_LEFT",
- "CAM_FRONT",
- "CAM_FRONT_RIGHT",
- "CAM_SIDE_RIGHT",
- "CAM_FRONT"
- ],
- "enable_camera_transforms": true,
- "enable_ego_transforms": true,
- "_3dbox_image_settings": {},
- "hdmap_image_settings": {},
- "image_description_settings": {
- "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/waymo_caption_v2_train.json",
- "time_list_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/waymo_caption_v2_times_train.json",
- "align_keys": [
- "time",
- "weather"
- ],
- "reorder_keys": true,
- "drop_rates": {
- "environment": 0.04,
- "objects": 0.08,
- "image_description": 0.16
- }
- },
- "stub_key_data_dict": {
- "crossview_mask": [
- "content",
- {
- "_class_name": "torch.tensor",
- "data": {
- "_class_name": "json.loads",
- "s": "[[1,1,0,0,0,0],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,0],[0,0,0,0,0,1]]"
- },
- "dtype": {
- "_class_name": "get_class",
- "class_name": "torch.bool"
- }
- }
- ]
- }
- },
- {
- "_class_name": "dwm.datasets.argoverse.MotionDataset",
- "fs": {
- "_class_name": "dwm.fs.ctar.CombinedTarFileSystem",
- "fs": {
- "_class_name": "dwm.fs.dirfs.DirFileSystem",
- "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan"
- },
- "paths": [
- "data/argoverse/av2/tars/sensor/train-000.tar",
- "data/argoverse/av2/tars/sensor/train-001.tar",
- "data/argoverse/av2/tars/sensor/train-002.tar",
- "data/argoverse/av2/tars/sensor/train-003.tar",
- "data/argoverse/av2/tars/sensor/train-004.tar",
- "data/argoverse/av2/tars/sensor/train-005.tar",
- "data/argoverse/av2/tars/sensor/train-006.tar",
- "data/argoverse/av2/tars/sensor/train-007.tar",
- "data/argoverse/av2/tars/sensor/train-008.tar",
- "data/argoverse/av2/tars/sensor/train-009.tar",
- "data/argoverse/av2/tars/sensor/train-010.tar",
- "data/argoverse/av2/tars/sensor/train-011.tar",
- "data/argoverse/av2/tars/sensor/train-012.tar",
- "data/argoverse/av2/tars/sensor/train-013.tar"
- ],
- "enable_cached_info": true
- },
- "sequence_length": 19,
- "fps_stride_tuples": [
- [
- 10,
- 0.1
- ]
- ],
- "sensor_channels": [
- "lidar",
- "cameras/ring_front_left",
- "cameras/ring_front_right",
- "cameras/ring_side_right",
- "cameras/ring_rear_right",
- "cameras/ring_rear_left",
- "cameras/ring_side_left"
- ],
- "enable_camera_transforms": true,
- "enable_ego_transforms": true,
- "_3dbox_image_settings": {},
- "hdmap_image_settings": {},
- "image_description_settings": {
- "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/av2_sensor_caption_v2_train.json",
- "time_list_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/av2_sensor_caption_v2_times_train.json",
- "align_keys": [
- "time",
- "weather"
- ],
- "reorder_keys": true,
- "drop_rates": {
- "environment": 0.04,
- "objects": 0.08,
- "image_description": 0.16
- }
- },
- "stub_key_data_dict": {
- "crossview_mask": [
- "content",
- {
- "_class_name": "torch.tensor",
- "data": {
- "_class_name": "json.loads",
- "s": "[[1,1,0,0,0,1],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,1],[1,0,0,0,1,1]]"
- },
- "dtype": {
- "_class_name": "get_class",
- "class_name": "torch.bool"
- }
- }
- ]
- }
- }
- ]
- },
- "transform_list": [
- {
- "old_key": "images",
- "new_key": "vae_images",
- "transform": {
- "_class_name": "torchvision.transforms.Compose",
- "transforms": [
- {
- "_class_name": "torchvision.transforms.Resize",
- "size": [
- 256,
- 448
- ]
- },
- {
- "_class_name": "torchvision.transforms.ToTensor"
- }
- ]
- }
- },
- {
- "old_key": "3dbox_images",
- "new_key": "3dbox_images",
- "transform": {
- "_class_name": "torchvision.transforms.Compose",
- "transforms": [
- {
- "_class_name": "torchvision.transforms.Resize",
- "size": [
- 256,
- 448
- ]
- },
- {
- "_class_name": "torchvision.transforms.ToTensor"
- }
- ]
- }
- },
- {
- "old_key": "hdmap_images",
- "new_key": "hdmap_images",
- "transform": {
- "_class_name": "torchvision.transforms.Compose",
- "transforms": [
- {
- "_class_name": "torchvision.transforms.Resize",
- "size": [
- 256,
- 448
- ]
- },
- {
- "_class_name": "torchvision.transforms.ToTensor"
- }
- ]
- }
- },
- {
- "old_key": "image_description",
- "new_key": "clip_text",
- "transform": {
- "_class_name": "dwm.datasets.common.Copy"
- },
- "stack": false
- }
- ],
- "pop_list": [
- "images",
- "lidar_points",
- "image_description"
- ]
- },
- "validation_dataset": {
- "_class_name": "dwm.datasets.common.DatasetAdapter",
- "base_dataset": {
- "_class_name": "torch.utils.data.ConcatDataset",
- "datasets": [
- {
- "_class_name": "dwm.datasets.nuscenes.MotionDataset",
- "fs": {
- "_class_name": "dwm.common.get_state",
- "key": "nuscenes_fs"
- },
- "dataset_name": "interp_12Hz_trainval",
- "split": "val",
- "sequence_length": 35,
- "fps_stride_tuples": [
- [
- 10,
- 20
- ]
- ],
- "sensor_channels": [
- "LIDAR_TOP",
- "CAM_FRONT_LEFT",
- "CAM_FRONT",
- "CAM_FRONT_RIGHT",
- "CAM_BACK_RIGHT",
- "CAM_BACK",
- "CAM_BACK_LEFT"
- ],
- "keyframe_only": true,
- "enable_synchronization_check": false,
- "enable_camera_transforms": true,
- "enable_ego_transforms": true,
- "_3dbox_image_settings": {},
- "hdmap_image_settings": {},
- "image_description_settings": {
- "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/nuscenes_v1.0-trainval_caption_v2_val.json",
- "time_list_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/nuscenes_v1.0-trainval_caption_v2_times_val.json",
- "align_keys": [
- "time",
- "weather"
- ]
- },
- "stub_key_data_dict": {
- "crossview_mask": [
- "content",
- {
- "_class_name": "torch.tensor",
- "data": {
- "_class_name": "json.loads",
- "s": "[[1,1,0,0,0,1],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,1],[1,0,0,0,1,1]]"
- },
- "dtype": {
- "_class_name": "get_class",
- "class_name": "torch.bool"
- }
- }
- ]
- }
- },
- {
- "_class_name": "dwm.datasets.waymo.MotionDataset",
- "fs": {
- "_class_name": "dwm.fs.dirfs.DirFileSystem",
- "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/data/waymo/waymo_open_dataset_v_1_4_3/validation"
- },
- "info_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/data/waymo/waymo_open_dataset_v_1_4_3/validation.info.json",
- "sequence_length": 35,
- "fps_stride_tuples": [
- [
- 10,
- 20
- ]
- ],
- "sensor_channels": [
- "LIDAR_TOP",
- "CAM_SIDE_LEFT",
- "CAM_FRONT_LEFT",
- "CAM_FRONT",
- "CAM_FRONT_RIGHT",
- "CAM_SIDE_RIGHT",
- "CAM_FRONT"
- ],
- "enable_camera_transforms": true,
- "enable_ego_transforms": true,
- "_3dbox_image_settings": {},
- "hdmap_image_settings": {},
- "image_description_settings": {
- "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/waymo_caption_v2_val.json",
- "time_list_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/waymo_caption_v2_times_val.json",
- "align_keys": [
- "time",
- "weather"
- ]
- },
- "stub_key_data_dict": {
- "crossview_mask": [
- "content",
- {
- "_class_name": "torch.tensor",
- "data": {
- "_class_name": "json.loads",
- "s": "[[1,1,0,0,0,0],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,0],[0,0,0,0,0,1]]"
- },
- "dtype": {
- "_class_name": "get_class",
- "class_name": "torch.bool"
- }
- }
- ]
- }
- },
- {
- "_class_name": "dwm.datasets.argoverse.MotionDataset",
- "fs": {
- "_class_name": "dwm.fs.ctar.CombinedTarFileSystem",
- "fs": {
- "_class_name": "dwm.fs.dirfs.DirFileSystem",
- "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan"
- },
- "paths": [
- "data/argoverse/av2/tars/sensor/val-000.tar",
- "data/argoverse/av2/tars/sensor/val-001.tar",
- "data/argoverse/av2/tars/sensor/val-002.tar"
- ],
- "enable_cached_info": true
- },
- "sequence_length": 35,
- "fps_stride_tuples": [
- [
- 10,
- 20
- ]
- ],
- "sensor_channels": [
- "lidar",
- "cameras/ring_front_left",
- "cameras/ring_front_right",
- "cameras/ring_side_right",
- "cameras/ring_rear_right",
- "cameras/ring_rear_left",
- "cameras/ring_side_left"
- ],
- "enable_camera_transforms": true,
- "enable_ego_transforms": true,
- "_3dbox_image_settings": {},
- "hdmap_image_settings": {},
- "image_description_settings": {
- "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/av2_sensor_caption_v2_val.json",
- "time_list_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/av2_sensor_caption_v2_times_val.json",
- "align_keys": [
- "time",
- "weather"
- ]
- },
- "stub_key_data_dict": {
- "crossview_mask": [
- "content",
- {
- "_class_name": "torch.tensor",
- "data": {
- "_class_name": "json.loads",
- "s": "[[1,1,0,0,0,1],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,1],[1,0,0,0,1,1]]"
- },
- "dtype": {
- "_class_name": "get_class",
- "class_name": "torch.bool"
- }
- }
- ]
- }
- }
- ]
- },
- "transform_list": [
- {
- "old_key": "images",
- "new_key": "vae_images",
- "transform": {
- "_class_name": "torchvision.transforms.Compose",
- "transforms": [
- {
- "_class_name": "torchvision.transforms.Resize",
- "size": [
- 256,
- 448
- ]
- },
- {
- "_class_name": "torchvision.transforms.ToTensor"
- }
- ]
- }
- },
- {
- "old_key": "3dbox_images",
- "new_key": "3dbox_images",
- "transform": {
- "_class_name": "torchvision.transforms.Compose",
- "transforms": [
- {
- "_class_name": "torchvision.transforms.Resize",
- "size": [
- 256,
- 448
- ]
- },
- {
- "_class_name": "torchvision.transforms.ToTensor"
- }
- ]
- }
- },
- {
- "old_key": "hdmap_images",
- "new_key": "hdmap_images",
- "transform": {
- "_class_name": "torchvision.transforms.Compose",
- "transforms": [
- {
- "_class_name": "torchvision.transforms.Resize",
- "size": [
- 256,
- 448
- ]
- },
- {
- "_class_name": "torchvision.transforms.ToTensor"
- }
- ]
- }
- },
- {
- "old_key": "image_description",
- "new_key": "clip_text",
- "transform": {
- "_class_name": "dwm.datasets.common.Copy"
- },
- "stack": false
- }
- ],
- "pop_list": [
- "images",
- "lidar_points",
- "image_description"
- ]
- },
- "training_dataloader": {
- "batch_size": 2,
- "num_workers": 3,
- "prefetch_factor": 3,
- "collate_fn": {
- "_class_name": "dwm.datasets.common.CollateFnIgnoring",
- "keys": [
- "clip_text"
- ]
- },
- "persistent_workers": true
- },
- "validation_dataloader": {
- "batch_size": 1,
- "num_workers": 1,
- "prefetch_factor": 3,
- "collate_fn": {
- "_class_name": "dwm.datasets.common.CollateFnIgnoring",
- "keys": [
- "clip_text"
- ]
- },
- "persistent_workers": true
- },
- "preview_dataloader": {
- "batch_size": 1,
- "num_workers": 1,
- "prefetch_factor": 1,
- "shuffle": true,
- "drop_last": true,
- "collate_fn": {
- "_class_name": "dwm.datasets.common.CollateFnIgnoring",
- "keys": [
- "clip_text"
- ]
- },
- "persistent_workers": true
- },
- "informations": {
- "fid": -1,
- "fvd": -1,
- "total_batch_sizes": 64,
- "steps": 30000
- }
-}
\ No newline at end of file
diff --git a/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa_warmup.json b/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwao.json
similarity index 77%
rename from configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa_warmup.json
rename to configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwao.json
index 1a9a82b..1a63d4d 100644
--- a/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa_warmup.json
+++ b/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwao.json
@@ -1,8 +1,8 @@
{
"device": "cuda",
"ddp_backend": "nccl",
- "train_epochs": 1,
- "generator_seed": 0,
+ "train_epochs": 4,
+ "generator_seed": 1,
"data_shuffle": true,
"fix_training_data_order": true,
"global_state": {
@@ -27,6 +27,19 @@
"data/nuscenes/nuScenes-map-expansion-v1.3.zip"
]
},
+ "opendv_czip_fs": {
+ "_class_name": "dwm.fs.czip.CombinedZipFileSystem",
+ "fs": {
+ "_class_name": "dwm.fs.dirfs.DirFileSystem",
+ "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/data/opendv"
+ },
+ "paths": [
+ "opendv-youtube-10hz-720_0.zip",
+ "opendv-youtube-10hz-720_1.zip",
+ "opendv-youtube-10hz-720_2.zip",
+ "opendv-youtube-10hz-720_3.zip"
+ ]
+ },
"device_mesh": {
"_class_name": "torch.distributed.device_mesh.init_device_mesh",
"device_type": "cuda",
@@ -38,7 +51,11 @@
},
"optimizer": {
"_class_name": "torch.optim.AdamW",
- "lr": 1.2e-4
+ "lr": 6e-5,
+ "betas": [
+ 0.9,
+ 0.975
+ ]
},
"pipeline": {
"_class_name": "dwm.pipelines.ctsd.CrossviewTemporalSD",
@@ -47,6 +64,7 @@
"cat_condition": true,
"cond_with_action": false,
"condition_on_all_frames": true,
+ "uncondition_image_color": 0.1255,
"added_time_ids": "fps_camera_transforms",
"camera_intrinsic_embedding_indices": [
0,
@@ -97,8 +115,7 @@
"_class_name": "get_class",
"class_name": "torch.float16"
}
- },
- "use_orig_params": true
+ }
},
"t5_fsdp_wrapper_settings": {
"sharding_strategy": {
@@ -126,19 +143,22 @@
"class_name": "torch.float16"
}
},
- "memory_efficient_batch": 16
+ "memory_efficient_batch": 12
},
"training_config": {
- "freezing_pattern": "^(transformer_blocks|time_text_embed)$",
"text_prompt_condition_ratio": 0.8,
"3dbox_condition_ratio": 0.8,
"hdmap_condition_ratio": 0.8,
- "reference_frame_count": 3,
+ "reference_frame_count": {
+ "1": 0.1,
+ "2": 0.3,
+ "3": 0.6
+ },
"generation_task_ratio": 0.25,
"image_generation_ratio": 0.3,
"all_reference_visible_ratio": 1,
- "reference_frame_scale_std": 0.02,
- "reference_frame_offset_std": 0.02,
+ "reference_frame_scale_std": 0.01,
+ "reference_frame_offset_std": 0.01,
"enable_grad_scaler": true
},
"inference_config": {
@@ -153,7 +173,7 @@
"autoregression_data_exception_for_take_sequence": [
"crossview_mask"
],
- "evaluation_item_count": 480
+ "evaluation_item_count": 640
},
"model": {
"_class_name": "dwm.models.crossview_temporal_dit.DiTCrossviewTemporalConditionModel",
@@ -199,7 +219,7 @@
],
"crossview_gradient_checkpointing": true,
"enable_temporal": true,
- "temporal_attention_type": "rowwise",
+ "temporal_attention_type": "pointwise",
"temporal_block_layers": [
2,
3,
@@ -456,6 +476,95 @@
}
]
}
+ },
+ {
+ "_class_name": "dwm.datasets.opendv.MotionDataset",
+ "fs": {
+ "_class_name": "dwm.common.get_state",
+ "key": "opendv_czip_fs"
+ },
+ "meta_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/nijingcheng/datasets/OpenDV-YouTube.json",
+ "sequence_length": 19,
+ "fps_stride_tuples": [
+ [
+ 10,
+ 5
+ ]
+ ],
+ "split": "Train",
+ "mini_batch": 6,
+ "ignore_list": [
+ "izhGt1GnGFk"
+ ],
+ "enable_pts": false,
+ "enable_fake_camera_transforms": true,
+ "enable_fake_3dbox_images": true,
+ "enable_fake_hdmap_images": true,
+ "fake_condition_image_color": [
+ 32,
+ 32,
+ 32
+ ],
+ "image_description_settings": {
+ "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/opendv_caption.json",
+ "candidates_times_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/opendv_candidates_times.json",
+ "seed": 5,
+ "reorder_keys": true,
+ "drop_rates": {
+ "environment": 0.04,
+ "objects": 0.08,
+ "image_description": 0.16
+ }
+ },
+ "stub_key_data_dict": {
+ "pts": [
+ "content",
+ {
+ "_class_name": "torch.zeros",
+ "size": [
+ 19,
+ 7
+ ]
+ }
+ ],
+ "crossview_mask": [
+ "content",
+ {
+ "_class_name": "torch.eye",
+ "n": 6,
+ "dtype": {
+ "_class_name": "get_class",
+ "class_name": "torch.bool"
+ }
+ }
+ ],
+ "ego_transforms": [
+ "content",
+ {
+ "_class_name": "torch.repeat_interleave",
+ "input": {
+ "_class_name": "torch.repeat_interleave",
+ "input": {
+ "_class_name": "torch.reshape",
+ "input": {
+ "_class_name": "torch.eye",
+ "n": 4
+ },
+ "shape": [
+ 1,
+ 1,
+ 4,
+ 4
+ ]
+ },
+ "repeats": 7,
+ "dim": 1
+ },
+ "repeats": 19,
+ "dim": 0
+ }
+ ]
+ }
}
]
},
@@ -529,6 +638,7 @@
"pop_list": [
"images",
"lidar_points",
+ "lidar_transforms",
"image_description"
]
},
@@ -703,6 +813,85 @@
}
]
}
+ },
+ {
+ "_class_name": "dwm.datasets.opendv.MotionDataset",
+ "fs": {
+ "_class_name": "dwm.common.get_state",
+ "key": "opendv_czip_fs"
+ },
+ "meta_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/nijingcheng/datasets/OpenDV-YouTube.json",
+ "sequence_length": 35,
+ "fps_stride_tuples": [
+ [
+ 10,
+ 75
+ ]
+ ],
+ "split": "Val",
+ "mini_batch": 6,
+ "enable_pts": false,
+ "enable_fake_camera_transforms": true,
+ "enable_fake_3dbox_images": true,
+ "enable_fake_hdmap_images": true,
+ "fake_condition_image_color": [
+ 32,
+ 32,
+ 32
+ ],
+ "image_description_settings": {
+ "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/opendv_caption.json",
+ "candidates_times_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/opendv_candidates_times.json"
+ },
+ "stub_key_data_dict": {
+ "pts": [
+ "content",
+ {
+ "_class_name": "torch.zeros",
+ "size": [
+ 35,
+ 7
+ ]
+ }
+ ],
+ "crossview_mask": [
+ "content",
+ {
+ "_class_name": "torch.eye",
+ "n": 6,
+ "dtype": {
+ "_class_name": "get_class",
+ "class_name": "torch.bool"
+ }
+ }
+ ],
+ "ego_transforms": [
+ "content",
+ {
+ "_class_name": "torch.repeat_interleave",
+ "input": {
+ "_class_name": "torch.repeat_interleave",
+ "input": {
+ "_class_name": "torch.reshape",
+ "input": {
+ "_class_name": "torch.eye",
+ "n": 4
+ },
+ "shape": [
+ 1,
+ 1,
+ 4,
+ 4
+ ]
+ },
+ "repeats": 7,
+ "dim": 1
+ },
+ "repeats": 35,
+ "dim": 0
+ }
+ ]
+ }
}
]
},
@@ -776,11 +965,43 @@
"pop_list": [
"images",
"lidar_points",
+ "lidar_transforms",
"image_description"
]
},
+ "mix_config": {
+ "256-448": [
+ 0.6,
+ [
+ [
+ 19,
+ 2,
+ 1.0
+ ]
+ ]
+ ],
+ "176-304": [
+ 0.3,
+ [
+ [
+ 19,
+ 4,
+ 1.0
+ ]
+ ]
+ ],
+ "144-256": [
+ 0.1,
+ [
+ [
+ 19,
+ 6,
+ 1.0
+ ]
+ ]
+ ]
+ },
"training_dataloader": {
- "batch_size": 2,
"num_workers": 3,
"prefetch_factor": 3,
"collate_fn": {
@@ -818,9 +1039,12 @@
"persistent_workers": true
},
"informations": {
- "fid": 14.75,
- "fvd": 268.69,
- "total_batch_sizes": 64,
- "steps": 5000
+ "fid": 9.46,
+ "fvd": 91.55,
+ "fvd_on_nusc_by_3_ref_frames": 38.83,
+ "fvd_on_nusc_by_1_ref_frames": 45.98,
+ "fvd_on_nusc_without_ref_frame": 117.50,
+ "average_total_batch_sizes": 96,
+ "steps": 40000
}
}
\ No newline at end of file
diff --git a/examples/ctsd_35_6views_video_generation_with_layout.json b/examples/ctsd_35_6views_video_generation_with_layout.json
new file mode 100644
index 0000000..bb0e83d
--- /dev/null
+++ b/examples/ctsd_35_6views_video_generation_with_layout.json
@@ -0,0 +1,292 @@
+{
+ "device": "cuda",
+ "generator_seed": 0,
+ "pipeline": {
+ "_class_name": "dwm.pipelines.ctsd.CrossviewTemporalSD",
+ "common_config": {
+ "frame_prediction_style": "ctsd",
+ "condition_on_all_frames": true,
+ "uncondition_image_color": 0.1255,
+ "added_time_ids": "fps_camera_transforms",
+ "camera_intrinsic_embedding_indices": [
+ 0,
+ 4,
+ 2,
+ 5
+ ],
+ "camera_intrinsic_denom_embedding_indices": [
+ 1,
+ 1,
+ 0,
+ 1
+ ],
+ "camera_transform_embedding_indices": [
+ 2,
+ 6,
+ 10,
+ 3,
+ 7,
+ 11
+ ],
+ "autocast": {
+ "device_type": "cuda"
+ },
+ "text_encoder_load_args": {
+ "variant": "fp16",
+ "torch_dtype": {
+ "_class_name": "get_class",
+ "class_name": "torch.float16"
+ },
+ "quantization_config": {
+ "_class_name": "diffusers.quantizers.quantization_config.BitsAndBytesConfig",
+ "load_in_4bit": true,
+ "bnb_4bit_quant_type": "nf4",
+ "bnb_4bit_compute_dtype": {
+ "_class_name": "get_class",
+ "class_name": "torch.float16"
+ }
+ }
+ },
+ "memory_efficient_batch": 12
+ },
+ "training_config": {},
+ "inference_config": {
+ "guidance_scale": 4,
+ "inference_steps": 40,
+ "preview_image_size": [
+ 448,
+ 252
+ ],
+ "sequence_length_per_iteration": 19,
+ "reference_frame_count": 3,
+ "autoregression_data_exception_for_take_sequence": [
+ "crossview_mask"
+ ]
+ },
+ "model": {
+ "_class_name": "dwm.models.crossview_temporal_dit.DiTCrossviewTemporalConditionModel",
+ "dual_attention_layers": [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12
+ ],
+ "attention_head_dim": 64,
+ "caption_projection_dim": 1536,
+ "in_channels": 16,
+ "joint_attention_dim": 4096,
+ "num_attention_heads": 24,
+ "num_layers": 24,
+ "out_channels": 16,
+ "patch_size": 2,
+ "pooled_projection_dim": 2048,
+ "pos_embed_max_size": 384,
+ "qk_norm": "rms_norm",
+ "qk_norm_on_additional_modules": "rms_norm",
+ "sample_size": 128,
+ "perspective_modeling_type": "implicit",
+ "projection_class_embeddings_input_dim": 2816,
+ "enable_crossview": true,
+ "crossview_attention_type": "rowwise",
+ "crossview_block_layers": [
+ 1,
+ 5,
+ 9,
+ 13,
+ 17,
+ 21
+ ],
+ "crossview_gradient_checkpointing": true,
+ "enable_temporal": true,
+ "temporal_attention_type": "pointwise",
+ "temporal_block_layers": [
+ 2,
+ 3,
+ 6,
+ 7,
+ 10,
+ 11,
+ 14,
+ 15,
+ 18,
+ 19,
+ 22,
+ 23
+ ],
+ "temporal_gradient_checkpointing": true,
+ "mixer_type": "AlphaBlender",
+ "merge_factor": 2,
+ "condition_image_adapter_config": {
+ "in_channels": 6,
+ "channels": [
+ 1536,
+ 1536,
+ 1536,
+ 1536,
+ 1536,
+ 1536
+ ],
+ "is_downblocks": [
+ true,
+ false,
+ false,
+ false,
+ false,
+ false
+ ],
+ "num_res_blocks": 2,
+ "downscale_factor": 8,
+ "use_zero_convs": true
+ }
+ },
+ "model_dtype": {
+ "_class_name": "get_class",
+ "class_name": "torch.float16"
+ },
+ "pretrained_model_name_or_path": "/mnt/afs/user/wuzehuan/Downloads/models/stable-diffusion-3.5-medium",
+ "model_checkpoint_path": "/mnt/afs/user/wuzehuan/Tasks/ctsd_35_tirda_bm_nwao/checkpoints/40000.pth"
+ },
+ "validation_dataset": {
+ "_class_name": "dwm.datasets.common.DatasetAdapter",
+ "base_dataset": {
+ "_class_name": "dwm.datasets.preview.PreviewDataset",
+ "json_file": "/mnt/afs/user/wuzehuan/Documents/DWM/output/carla_town04_package/data.json",
+ "sequence_length": 179,
+ "fps_stride_tuples": [
+ [
+ 10,
+ 60
+ ]
+ ],
+ "sensor_channels": [
+ "CAM_FRONT_LEFT",
+ "CAM_FRONT",
+ "CAM_FRONT_RIGHT",
+ "CAM_BACK_RIGHT",
+ "CAM_BACK",
+ "CAM_BACK_LEFT"
+ ],
+ "enable_camera_transforms": true,
+ "use_hdmap": true,
+ "use_3dbox": true,
+ "drop_vehicle_color": true,
+ "stub_key_data_dict": {
+ "crossview_mask": [
+ "content",
+ {
+ "_class_name": "torch.tensor",
+ "data": {
+ "_class_name": "json.loads",
+ "s": "[[1,1,0,0,0,1],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,1],[1,0,0,0,1,1]]"
+ },
+ "dtype": {
+ "_class_name": "get_class",
+ "class_name": "torch.bool"
+ }
+ }
+ ]
+ }
+ },
+ "transform_list": [
+ {
+ "old_key": "images",
+ "new_key": "vae_images",
+ "transform": {
+ "_class_name": "torchvision.transforms.Compose",
+ "transforms": [
+ {
+ "_class_name": "torchvision.transforms.Resize",
+ "size": [
+ 256,
+ 448
+ ]
+ },
+ {
+ "_class_name": "torchvision.transforms.ToTensor"
+ }
+ ]
+ }
+ },
+ {
+ "old_key": "3dbox_images",
+ "new_key": "3dbox_images",
+ "transform": {
+ "_class_name": "torchvision.transforms.Compose",
+ "transforms": [
+ {
+ "_class_name": "torchvision.transforms.Resize",
+ "size": [
+ 256,
+ 448
+ ]
+ },
+ {
+ "_class_name": "torchvision.transforms.ToTensor"
+ }
+ ]
+ }
+ },
+ {
+ "old_key": "hdmap_images",
+ "new_key": "hdmap_images",
+ "transform": {
+ "_class_name": "torchvision.transforms.Compose",
+ "transforms": [
+ {
+ "_class_name": "torchvision.transforms.Resize",
+ "size": [
+ 256,
+ 448
+ ]
+ },
+ {
+ "_class_name": "torchvision.transforms.ToTensor"
+ }
+ ]
+ }
+ },
+ {
+ "old_key": "image_description",
+ "new_key": "clip_text",
+ "transform": {
+ "_class_name": "dwm.datasets.common.Copy"
+ },
+ "stack": false
+ }
+ ],
+ "pop_list": [
+ "images",
+ "pred_images",
+ "image_description"
+ ]
+ },
+ "validation_dataloader": {
+ "batch_size": 1,
+ "num_workers": 1,
+ "collate_fn": {
+ "_class_name": "dwm.datasets.common.CollateFnIgnoring",
+ "keys": [
+ "clip_text"
+ ]
+ }
+ },
+ "preview_dataloader": {
+ "batch_size": 1,
+ "num_workers": 1,
+ "collate_fn": {
+ "_class_name": "dwm.datasets.common.CollateFnIgnoring",
+ "keys": [
+ "clip_text"
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git a/src/dwm/datasets/common.py b/src/dwm/datasets/common.py
index b3c6872..8287dc6 100644
--- a/src/dwm/datasets/common.py
+++ b/src/dwm/datasets/common.py
@@ -51,9 +51,6 @@ def apply_transform(transform, a, stack: bool = True):
else:
return transform(a)
- def apply_temporal_transform(transform, a):
- return transform(a)
-
def __init__(
self, base_dataset: torch.utils.data.Dataset, transform_list: list,
pop_list=None
@@ -70,9 +67,8 @@ def __getitem__(self, index):
if isinstance(index, int):
item = self.base_dataset[index]
for i in self.transform_list:
- if getattr(i["transform"], 'is_temporal_transform', False):
- item[i["new_key"]] = DatasetAdapter.apply_temporal_transform(
- i["transform"], item[i["old_key"]])
+ if i.get("is_dynamic_transform", False):
+ item = i["transform"](item)
else:
item[i["new_key"]] = DatasetAdapter.apply_transform(
i["transform"], item[i["old_key"]],
@@ -313,7 +309,6 @@ def make_image_description_string(caption_dict: dict, settings: dict):
probabilities of the corresponding key elements in the
caption_dict being dropped.
"""
-
default_image_description_keys = [
"time", "weather", "environment", "objects", "image_description"
]
diff --git a/src/dwm/datasets/nuscenes.py b/src/dwm/datasets/nuscenes.py
index 5ba506d..df32511 100755
--- a/src/dwm/datasets/nuscenes.py
+++ b/src/dwm/datasets/nuscenes.py
@@ -732,7 +732,7 @@ def get_hdmap_bev_image(
MotionDataset.default_bev_from_ego_transform)
color_table = hdmap_bev_settings.get(
"color_table", MotionDataset.default_hdmap_color_table)
-
+ fill_map = hdmap_bev_settings.get("fill_map", True)
# get the transform from the world (map) space to the BEV space
world_from_ego = dwm.datasets.common.get_transform(
sample_data["rotation"], sample_data["translation"])
@@ -759,14 +759,14 @@ def get_hdmap_bev_image(
for polygon_token in i["polygon_tokens"]:
MotionDataset.draw_polygon_to_bev_image(
polygons[polygon_token], nodes, draw, bev_from_world,
- (0, 0, 255), pen_width, solid=True)
+ (0, 0, 255), pen_width, solid=fill_map)
if "ped_crossing" in color_table and "ped_crossing" in map:
pen_color = tuple(color_table["ped_crossing"])
for i in map["ped_crossing"]:
MotionDataset.draw_polygon_to_bev_image(
polygons[i["polygon_token"]], nodes, draw, bev_from_world,
- (255, 0, 0), pen_width, solid=True)
+ (255, 0, 0), pen_width, solid=fill_map)
if "lane" in color_table and "lane" in map:
pen_color = tuple(color_table["lane"])
diff --git a/src/dwm/fs/s3fs.py b/src/dwm/fs/s3fs.py
index 0ab8aec..74afd5b 100644
--- a/src/dwm/fs/s3fs.py
+++ b/src/dwm/fs/s3fs.py
@@ -138,7 +138,7 @@ def _open(
def ls(self, path, detail=True, **kwargs):
self.reinit_if_forked()
bucket, key = S3File.find_bucket_key(path)
- if not key.endswith("/"):
+ if len(key) > 0 and not key.endswith("/"):
key = key + "/"
# NOTE: only files are listed
@@ -151,7 +151,7 @@ def ls(self, path, detail=True, **kwargs):
else:
response = self.client.list_objects(
Bucket=bucket, Delimiter="/", Prefix=key,
- ContinuationToken=continuation_token)
+ Marker=continuation_token)
if "Contents" in response:
for i in response["Contents"]:
@@ -165,7 +165,7 @@ def ls(self, path, detail=True, **kwargs):
})
if response["IsTruncated"]:
- continuation_token = response["NextContinuationToken"]
+ continuation_token = response["NextMarker"]
else:
break
diff --git a/src/dwm/models/crossview_temporal_dit.py b/src/dwm/models/crossview_temporal_dit.py
index d0d1e50..06e223f 100644
--- a/src/dwm/models/crossview_temporal_dit.py
+++ b/src/dwm/models/crossview_temporal_dit.py
@@ -350,13 +350,13 @@ def forward_temporal_block_and_mix_result(
b=batch_size, v=view_count, w=width)
else:
temporal_hidden_states = einops.rearrange(
- temporal_hidden_states, "(b t v) hw c -> (bv hw) t c",
+ temporal_hidden_states, "(b t v) hw c -> (b v hw) t c",
b=batch_size, t=sequence_length)
temporal_hidden_states = temporal_block(
temporal_hidden_states)
temporal_hidden_states = einops.rearrange(
- temporal_hidden_states, "(bv hw) t c -> (b t v) hw c",
- b=batch_size, t=sequence_length)
+ temporal_hidden_states, "(b v hw) t c -> (b t v) hw c",
+ b=batch_size, v=view_count, t=sequence_length)
return mixer(
hidden_states.view(