diff --git a/README.md b/README.md
index 21bc556..c9a8232 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,9 @@
 # Open Driving World Models (OpenDWM)
 
-[[中文简介](README_intro_zh.md)]
+[![Youtube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/j9RRj-xzOA4) [](README_intro_zh.md)
 
 https://github.com/user-attachments/assets/649d3b81-3b1f-44f9-9f51-4d1ed7756476
 
-[Video link](https://youtu.be/j9RRj-xzOA4)
-
 Welcome to the OpenDWM project! This is an open-source initiative, focusing on autonomous driving video generation. Our mission is to provide a high-quality, controllable tool for generating autonomous driving videos using the latest technology. We aim to build a codebase that is both user-friendly and highly reusable, and hope to continuously improve the project through the collective wisdom of the community.
 
 The driving world models generate multi-view images or videos of autonomous driving scenes based on text and road environment layout conditions. Whether it's the environment, weather conditions, vehicle type, or driving path, you can adjust them according to your needs.
@@ -49,6 +47,7 @@ python -m pip install torch==2.5.1 torchvision==0.20.1
 ```
 
 Clone the repository, then install the dependencies.
+
 ```
 cd DWM
 git submodule update --init --recursive
@@ -62,8 +61,8 @@ Our cross-view temporal SD (CTSD) pipeline support loading the pretrained SD 2.1
 | Base model | Text conditioned driving generation | Text and layout (box, map) conditioned driving generation |
 | :-: | :-: | :-: |
 | [SD 2.1](https://huggingface.co/stabilityai/stable-diffusion-2-1) | [Config](configs/ctsd/multi_datasets/ctsd_21_tirda_nwao.json), [Download](http://103.237.29.236:10030/ctsd_21_tirda_nwao_30k.pth) | [Config](configs/ctsd/multi_datasets/ctsd_21_tirda_bm_nwa.json), [Download](http://103.237.29.236:10030/ctsd_21_tirda_bm_nwa_30k.pth) |
-| [SD 3.0](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) | | [UniMLVG Config](configs/ctsd/unimlvg/unimlvg_stage3_tirda_nwa.json), [Download](http://103.237.29.236:10030/ctsd_unimlvg_tirda_bm_nwa_60k.pth) |
-| [SD 3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | [Config](configs/ctsd/multi_datasets/ctsd_35_tirda_nwao.json), [Download](http://103.237.29.236:10030/ctsd_35_tirda_nwao_20k.pth) | [Config](configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa.json), Released by 2025-3-1 |
+| [SD 3.0](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers) | | [UniMLVG Config](configs/ctsd/unimlvg/ctsd_unimlvg_stage3_tirda_bm_nwa.json), [Download](http://103.237.29.236:10030/ctsd_unimlvg_tirda_bm_nwa_60k.pth) |
+| [SD 3.5](https://huggingface.co/stabilityai/stable-diffusion-3.5-medium) | [Config](configs/ctsd/multi_datasets/ctsd_35_tirda_nwao.json), [Download](http://103.237.29.236:10030/ctsd_35_tirda_nwao_20k.pth) | [Config](configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwao.json), [Download](http://103.237.29.236:10030/ctsd_35_tirda_bm_nwao_40k.pth) |
 
 ## Examples
 
@@ -71,18 +70,18 @@ Our cross-view temporal SD (CTSD) pipeline support loading the pretrained SD 2.1
 
 Download base model (for VAE, text encoders, scheduler config) and driving generation model checkpoint, and edit the [path](examples/ctsd_35_6views_image_generation.json#L102) and [prompts](examples/ctsd_35_6views_image_generation.json#L221) in the JSON config, then run this command.
 
-```
+```bash
 PYTHONPATH=src python examples/ctsd_generation_example.py -c examples/ctsd_35_6views_image_generation.json -o output/ctsd_35_6views_image_generation
 ```
 
 ### Layout conditioned T2V generation with CTSD pipeline
 
-1. Download base model (for VAE, text encoders, scheduler config) and driving generation model checkpoint, and edit the [path](examples/ctsd_21_6views_video_generation_with_layout.json#L119) in the JSON config.
-2. Download layout resource package [nuscenes_scene-0627_package.zip](http://103.237.29.236:10030/nuscenes_scene-0627_package.zip) and unzip to the `{RESOURCE_PATH}`. Then edit the meta [path](examples/ctsd_21_6views_video_generation_with_layout.json#L129) as `{RESOURCE_PATH}/data.json` in the JSON config.
+1. Download base model (for VAE, text encoders, scheduler config) and driving generation model checkpoint, and edit the [path](examples/ctsd_35_6views_video_generation_with_layout.json#L156) in the JSON config.
+2. Download layout resource package ([nuscenes_scene-0627_package.zip](http://103.237.29.236:10030/nuscenes_scene-0627_package.zip), or [carla_town04_package](http://103.237.29.236:10030/carla_town04_package.zip)) and unzip to the `{RESOURCE_PATH}`. Then edit the meta [path](examples/ctsd_35_6views_video_generation_with_layout.json#L162) as `{RESOURCE_PATH}/data.json` in the JSON config.
 3. Run this command to generate the video.
 
-``` -PYTHONPATH=src python src/dwm/preview.py -c examples/ctsd_unimlvg_6views_video_generation.json -o output/ctsd_unimlvg_6views_video_generation +```bash +PYTHONPATH=src python src/dwm/preview.py -c examples/ctsd_35_6views_video_generation_with_layout.json -o output/ctsd_35_6views_video_generation_with_layout ``` ## Train diff --git a/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa.json b/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa.json deleted file mode 100644 index 982a078..0000000 --- a/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa.json +++ /dev/null @@ -1,828 +0,0 @@ -{ - "device": "cuda", - "ddp_backend": "nccl", - "train_epochs": 6, - "generator_seed": 0, - "data_shuffle": true, - "fix_training_data_order": true, - "global_state": { - "nuscenes_fs": { - "_class_name": "dwm.fs.czip.CombinedZipFileSystem", - "fs": { - "_class_name": "dwm.fs.dirfs.DirFileSystem", - "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan" - }, - "paths": [ - "workspaces/worldmodels/data/nuscenes/interp_12Hz_trainval.zip", - "data/nuscenes/v1.0-trainval01_blobs.zip", - "data/nuscenes/v1.0-trainval02_blobs.zip", - "data/nuscenes/v1.0-trainval03_blobs.zip", - "data/nuscenes/v1.0-trainval04_blobs.zip", - "data/nuscenes/v1.0-trainval05_blobs.zip", - "data/nuscenes/v1.0-trainval06_blobs.zip", - "data/nuscenes/v1.0-trainval07_blobs.zip", - "data/nuscenes/v1.0-trainval08_blobs.zip", - "data/nuscenes/v1.0-trainval09_blobs.zip", - "data/nuscenes/v1.0-trainval10_blobs.zip", - "data/nuscenes/nuScenes-map-expansion-v1.3.zip" - ] - }, - "device_mesh": { - "_class_name": "torch.distributed.device_mesh.init_device_mesh", - "device_type": "cuda", - "mesh_shape": [ - 4, - 8 - ] - } - }, - "optimizer": { - "_class_name": "torch.optim.AdamW", - "lr": 6e-5, - "betas": [ - 0.9, - 0.975 - ] - }, - "pipeline": { - "_class_name": "dwm.pipelines.ctsd.CrossviewTemporalSD", - "common_config": { - "frame_prediction_style": "ctsd", - "cat_condition": true, - "cond_with_action": false, - 
"condition_on_all_frames": true, - "added_time_ids": "fps_camera_transforms", - "camera_intrinsic_embedding_indices": [ - 0, - 4, - 2, - 5 - ], - "camera_intrinsic_denom_embedding_indices": [ - 1, - 1, - 0, - 1 - ], - "camera_transform_embedding_indices": [ - 2, - 6, - 10, - 3, - 7, - 11 - ], - "distribution_framework": "fsdp", - "ddp_wrapper_settings": { - "sharding_strategy": { - "_class_name": "torch.distributed.fsdp.ShardingStrategy", - "value": 4 - }, - "device_mesh": { - "_class_name": "dwm.common.get_state", - "key": "device_mesh" - }, - "auto_wrap_policy": { - "_class_name": "torch.distributed.fsdp.wrap.ModuleWrapPolicy", - "module_classes": [ - { - "_class_name": "get_class", - "class_name": "diffusers.models.attention.JointTransformerBlock" - }, - { - "_class_name": "get_class", - "class_name": "dwm.models.crossview_temporal.VTSelfAttentionBlock" - } - ] - }, - "mixed_precision": { - "_class_name": "torch.distributed.fsdp.MixedPrecision", - "param_dtype": { - "_class_name": "get_class", - "class_name": "torch.float16" - } - } - }, - "t5_fsdp_wrapper_settings": { - "sharding_strategy": { - "_class_name": "torch.distributed.fsdp.ShardingStrategy", - "value": 4 - }, - "device_mesh": { - "_class_name": "dwm.common.get_state", - "key": "device_mesh" - }, - "auto_wrap_policy": { - "_class_name": "torch.distributed.fsdp.wrap.ModuleWrapPolicy", - "module_classes": [ - { - "_class_name": "get_class", - "class_name": "transformers.models.t5.modeling_t5.T5Block" - } - ] - } - }, - "text_encoder_load_args": { - "variant": "fp16", - "torch_dtype": { - "_class_name": "get_class", - "class_name": "torch.float16" - } - }, - "memory_efficient_batch": 16 - }, - "training_config": { - "text_prompt_condition_ratio": 0.8, - "3dbox_condition_ratio": 0.8, - "hdmap_condition_ratio": 0.8, - "reference_frame_count": 3, - "generation_task_ratio": 0.25, - "image_generation_ratio": 0.3, - "all_reference_visible_ratio": 1, - "reference_frame_scale_std": 0.02, - 
"reference_frame_offset_std": 0.02, - "enable_grad_scaler": true - }, - "inference_config": { - "guidance_scale": 4, - "inference_steps": 40, - "preview_image_size": [ - 448, - 252 - ], - "sequence_length_per_iteration": 19, - "reference_frame_count": 3, - "autoregression_data_exception_for_take_sequence": [ - "crossview_mask" - ], - "evaluation_item_count": 480 - }, - "model": { - "_class_name": "dwm.models.crossview_temporal_dit.DiTCrossviewTemporalConditionModel", - "dual_attention_layers": [ - 0, - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12 - ], - "attention_head_dim": 64, - "caption_projection_dim": 1536, - "in_channels": 16, - "joint_attention_dim": 4096, - "num_attention_heads": 24, - "num_layers": 24, - "out_channels": 16, - "patch_size": 2, - "pooled_projection_dim": 2048, - "pos_embed_max_size": 384, - "qk_norm": "rms_norm", - "qk_norm_on_additional_modules": "rms_norm", - "sample_size": 128, - "perspective_modeling_type": "implicit", - "projection_class_embeddings_input_dim": 2816, - "enable_crossview": true, - "crossview_attention_type": "rowwise", - "crossview_block_layers": [ - 1, - 5, - 9, - 13, - 17, - 21 - ], - "crossview_gradient_checkpointing": true, - "enable_temporal": true, - "temporal_attention_type": "rowwise", - "temporal_block_layers": [ - 2, - 3, - 6, - 7, - 10, - 11, - 14, - 15, - 18, - 19, - 22, - 23 - ], - "temporal_gradient_checkpointing": true, - "mixer_type": "AlphaBlender", - "merge_factor": 2, - "condition_image_adapter_config": { - "in_channels": 6, - "channels": [ - 1536, - 1536, - 1536, - 1536, - 1536, - 1536 - ], - "is_downblocks": [ - true, - false, - false, - false, - false, - false - ], - "num_res_blocks": 2, - "downscale_factor": 8, - "use_zero_convs": true - } - }, - "pretrained_model_name_or_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/models/stable-diffusion-3.5-medium", - "model_checkpoint_path": "/mnt/afs/user/wuzehuan/Tasks/ctsd_35_tirda_bm_nwa_warmup/checkpoints/5000.pth", - 
"model_load_state_args": { - "strict": false - }, - "metrics": { - "fid": { - "_class_name": "torchmetrics.image.fid.FrechetInceptionDistance", - "normalize": true - }, - "fvd": { - "_class_name": "dwm.metrics.fvd.FrechetVideoDistance", - "inception_3d_checkpoint_path": "/mnt/afs/user/wuzehuan/Documents/DWM/externals/TATS/tats/fvd/i3d_pretrained_400.pt", - "sequence_count": 16 - } - } - }, - "training_dataset": { - "_class_name": "dwm.datasets.common.DatasetAdapter", - "base_dataset": { - "_class_name": "torch.utils.data.ConcatDataset", - "datasets": [ - { - "_class_name": "dwm.datasets.nuscenes.MotionDataset", - "fs": { - "_class_name": "dwm.common.get_state", - "key": "nuscenes_fs" - }, - "dataset_name": "interp_12Hz_trainval", - "split": "train", - "sequence_length": 19, - "fps_stride_tuples": [ - [ - 10, - 0.1 - ] - ], - "sensor_channels": [ - "LIDAR_TOP", - "CAM_FRONT_LEFT", - "CAM_FRONT", - "CAM_FRONT_RIGHT", - "CAM_BACK_RIGHT", - "CAM_BACK", - "CAM_BACK_LEFT" - ], - "keyframe_only": true, - "enable_camera_transforms": true, - "enable_ego_transforms": true, - "_3dbox_image_settings": {}, - "hdmap_image_settings": {}, - "image_description_settings": { - "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/nuscenes_v1.0-trainval_caption_v2_train.json", - "time_list_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/nuscenes_v1.0-trainval_caption_v2_times_train.json", - "align_keys": [ - "time", - "weather" - ], - "reorder_keys": true, - "drop_rates": { - "environment": 0.04, - "objects": 0.08, - "image_description": 0.16 - } - }, - "stub_key_data_dict": { - "crossview_mask": [ - "content", - { - "_class_name": "torch.tensor", - "data": { - "_class_name": "json.loads", - "s": "[[1,1,0,0,0,1],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,1],[1,0,0,0,1,1]]" - }, - "dtype": { - "_class_name": "get_class", - "class_name": "torch.bool" - } - } - ] - } - }, - { - 
"_class_name": "dwm.datasets.waymo.MotionDataset", - "fs": { - "_class_name": "dwm.fs.dirfs.DirFileSystem", - "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/data/waymo/waymo_open_dataset_v_1_4_3/training" - }, - "info_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/data/waymo/waymo_open_dataset_v_1_4_3/training.info.json", - "sequence_length": 19, - "fps_stride_tuples": [ - [ - 10, - 0.1 - ] - ], - "sensor_channels": [ - "LIDAR_TOP", - "CAM_SIDE_LEFT", - "CAM_FRONT_LEFT", - "CAM_FRONT", - "CAM_FRONT_RIGHT", - "CAM_SIDE_RIGHT", - "CAM_FRONT" - ], - "enable_camera_transforms": true, - "enable_ego_transforms": true, - "_3dbox_image_settings": {}, - "hdmap_image_settings": {}, - "image_description_settings": { - "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/waymo_caption_v2_train.json", - "time_list_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/waymo_caption_v2_times_train.json", - "align_keys": [ - "time", - "weather" - ], - "reorder_keys": true, - "drop_rates": { - "environment": 0.04, - "objects": 0.08, - "image_description": 0.16 - } - }, - "stub_key_data_dict": { - "crossview_mask": [ - "content", - { - "_class_name": "torch.tensor", - "data": { - "_class_name": "json.loads", - "s": "[[1,1,0,0,0,0],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,0],[0,0,0,0,0,1]]" - }, - "dtype": { - "_class_name": "get_class", - "class_name": "torch.bool" - } - } - ] - } - }, - { - "_class_name": "dwm.datasets.argoverse.MotionDataset", - "fs": { - "_class_name": "dwm.fs.ctar.CombinedTarFileSystem", - "fs": { - "_class_name": "dwm.fs.dirfs.DirFileSystem", - "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan" - }, - "paths": [ - "data/argoverse/av2/tars/sensor/train-000.tar", - "data/argoverse/av2/tars/sensor/train-001.tar", - "data/argoverse/av2/tars/sensor/train-002.tar", - "data/argoverse/av2/tars/sensor/train-003.tar", - 
"data/argoverse/av2/tars/sensor/train-004.tar", - "data/argoverse/av2/tars/sensor/train-005.tar", - "data/argoverse/av2/tars/sensor/train-006.tar", - "data/argoverse/av2/tars/sensor/train-007.tar", - "data/argoverse/av2/tars/sensor/train-008.tar", - "data/argoverse/av2/tars/sensor/train-009.tar", - "data/argoverse/av2/tars/sensor/train-010.tar", - "data/argoverse/av2/tars/sensor/train-011.tar", - "data/argoverse/av2/tars/sensor/train-012.tar", - "data/argoverse/av2/tars/sensor/train-013.tar" - ], - "enable_cached_info": true - }, - "sequence_length": 19, - "fps_stride_tuples": [ - [ - 10, - 0.1 - ] - ], - "sensor_channels": [ - "lidar", - "cameras/ring_front_left", - "cameras/ring_front_right", - "cameras/ring_side_right", - "cameras/ring_rear_right", - "cameras/ring_rear_left", - "cameras/ring_side_left" - ], - "enable_camera_transforms": true, - "enable_ego_transforms": true, - "_3dbox_image_settings": {}, - "hdmap_image_settings": {}, - "image_description_settings": { - "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/av2_sensor_caption_v2_train.json", - "time_list_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/av2_sensor_caption_v2_times_train.json", - "align_keys": [ - "time", - "weather" - ], - "reorder_keys": true, - "drop_rates": { - "environment": 0.04, - "objects": 0.08, - "image_description": 0.16 - } - }, - "stub_key_data_dict": { - "crossview_mask": [ - "content", - { - "_class_name": "torch.tensor", - "data": { - "_class_name": "json.loads", - "s": "[[1,1,0,0,0,1],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,1],[1,0,0,0,1,1]]" - }, - "dtype": { - "_class_name": "get_class", - "class_name": "torch.bool" - } - } - ] - } - } - ] - }, - "transform_list": [ - { - "old_key": "images", - "new_key": "vae_images", - "transform": { - "_class_name": "torchvision.transforms.Compose", - "transforms": [ - { - "_class_name": "torchvision.transforms.Resize", - 
"size": [ - 256, - 448 - ] - }, - { - "_class_name": "torchvision.transforms.ToTensor" - } - ] - } - }, - { - "old_key": "3dbox_images", - "new_key": "3dbox_images", - "transform": { - "_class_name": "torchvision.transforms.Compose", - "transforms": [ - { - "_class_name": "torchvision.transforms.Resize", - "size": [ - 256, - 448 - ] - }, - { - "_class_name": "torchvision.transforms.ToTensor" - } - ] - } - }, - { - "old_key": "hdmap_images", - "new_key": "hdmap_images", - "transform": { - "_class_name": "torchvision.transforms.Compose", - "transforms": [ - { - "_class_name": "torchvision.transforms.Resize", - "size": [ - 256, - 448 - ] - }, - { - "_class_name": "torchvision.transforms.ToTensor" - } - ] - } - }, - { - "old_key": "image_description", - "new_key": "clip_text", - "transform": { - "_class_name": "dwm.datasets.common.Copy" - }, - "stack": false - } - ], - "pop_list": [ - "images", - "lidar_points", - "image_description" - ] - }, - "validation_dataset": { - "_class_name": "dwm.datasets.common.DatasetAdapter", - "base_dataset": { - "_class_name": "torch.utils.data.ConcatDataset", - "datasets": [ - { - "_class_name": "dwm.datasets.nuscenes.MotionDataset", - "fs": { - "_class_name": "dwm.common.get_state", - "key": "nuscenes_fs" - }, - "dataset_name": "interp_12Hz_trainval", - "split": "val", - "sequence_length": 35, - "fps_stride_tuples": [ - [ - 10, - 20 - ] - ], - "sensor_channels": [ - "LIDAR_TOP", - "CAM_FRONT_LEFT", - "CAM_FRONT", - "CAM_FRONT_RIGHT", - "CAM_BACK_RIGHT", - "CAM_BACK", - "CAM_BACK_LEFT" - ], - "keyframe_only": true, - "enable_synchronization_check": false, - "enable_camera_transforms": true, - "enable_ego_transforms": true, - "_3dbox_image_settings": {}, - "hdmap_image_settings": {}, - "image_description_settings": { - "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/nuscenes_v1.0-trainval_caption_v2_val.json", - "time_list_dict_path": 
"/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/nuscenes_v1.0-trainval_caption_v2_times_val.json", - "align_keys": [ - "time", - "weather" - ] - }, - "stub_key_data_dict": { - "crossview_mask": [ - "content", - { - "_class_name": "torch.tensor", - "data": { - "_class_name": "json.loads", - "s": "[[1,1,0,0,0,1],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,1],[1,0,0,0,1,1]]" - }, - "dtype": { - "_class_name": "get_class", - "class_name": "torch.bool" - } - } - ] - } - }, - { - "_class_name": "dwm.datasets.waymo.MotionDataset", - "fs": { - "_class_name": "dwm.fs.dirfs.DirFileSystem", - "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/data/waymo/waymo_open_dataset_v_1_4_3/validation" - }, - "info_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/data/waymo/waymo_open_dataset_v_1_4_3/validation.info.json", - "sequence_length": 35, - "fps_stride_tuples": [ - [ - 10, - 20 - ] - ], - "sensor_channels": [ - "LIDAR_TOP", - "CAM_SIDE_LEFT", - "CAM_FRONT_LEFT", - "CAM_FRONT", - "CAM_FRONT_RIGHT", - "CAM_SIDE_RIGHT", - "CAM_FRONT" - ], - "enable_camera_transforms": true, - "enable_ego_transforms": true, - "_3dbox_image_settings": {}, - "hdmap_image_settings": {}, - "image_description_settings": { - "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/waymo_caption_v2_val.json", - "time_list_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/waymo_caption_v2_times_val.json", - "align_keys": [ - "time", - "weather" - ] - }, - "stub_key_data_dict": { - "crossview_mask": [ - "content", - { - "_class_name": "torch.tensor", - "data": { - "_class_name": "json.loads", - "s": "[[1,1,0,0,0,0],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,0],[0,0,0,0,0,1]]" - }, - "dtype": { - "_class_name": "get_class", - "class_name": "torch.bool" - } - } - ] - } - }, - { - "_class_name": "dwm.datasets.argoverse.MotionDataset", - 
"fs": { - "_class_name": "dwm.fs.ctar.CombinedTarFileSystem", - "fs": { - "_class_name": "dwm.fs.dirfs.DirFileSystem", - "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan" - }, - "paths": [ - "data/argoverse/av2/tars/sensor/val-000.tar", - "data/argoverse/av2/tars/sensor/val-001.tar", - "data/argoverse/av2/tars/sensor/val-002.tar" - ], - "enable_cached_info": true - }, - "sequence_length": 35, - "fps_stride_tuples": [ - [ - 10, - 20 - ] - ], - "sensor_channels": [ - "lidar", - "cameras/ring_front_left", - "cameras/ring_front_right", - "cameras/ring_side_right", - "cameras/ring_rear_right", - "cameras/ring_rear_left", - "cameras/ring_side_left" - ], - "enable_camera_transforms": true, - "enable_ego_transforms": true, - "_3dbox_image_settings": {}, - "hdmap_image_settings": {}, - "image_description_settings": { - "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/av2_sensor_caption_v2_val.json", - "time_list_dict_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/av2_sensor_caption_v2_times_val.json", - "align_keys": [ - "time", - "weather" - ] - }, - "stub_key_data_dict": { - "crossview_mask": [ - "content", - { - "_class_name": "torch.tensor", - "data": { - "_class_name": "json.loads", - "s": "[[1,1,0,0,0,1],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,1],[1,0,0,0,1,1]]" - }, - "dtype": { - "_class_name": "get_class", - "class_name": "torch.bool" - } - } - ] - } - } - ] - }, - "transform_list": [ - { - "old_key": "images", - "new_key": "vae_images", - "transform": { - "_class_name": "torchvision.transforms.Compose", - "transforms": [ - { - "_class_name": "torchvision.transforms.Resize", - "size": [ - 256, - 448 - ] - }, - { - "_class_name": "torchvision.transforms.ToTensor" - } - ] - } - }, - { - "old_key": "3dbox_images", - "new_key": "3dbox_images", - "transform": { - "_class_name": "torchvision.transforms.Compose", - "transforms": [ - { - "_class_name": 
"torchvision.transforms.Resize", - "size": [ - 256, - 448 - ] - }, - { - "_class_name": "torchvision.transforms.ToTensor" - } - ] - } - }, - { - "old_key": "hdmap_images", - "new_key": "hdmap_images", - "transform": { - "_class_name": "torchvision.transforms.Compose", - "transforms": [ - { - "_class_name": "torchvision.transforms.Resize", - "size": [ - 256, - 448 - ] - }, - { - "_class_name": "torchvision.transforms.ToTensor" - } - ] - } - }, - { - "old_key": "image_description", - "new_key": "clip_text", - "transform": { - "_class_name": "dwm.datasets.common.Copy" - }, - "stack": false - } - ], - "pop_list": [ - "images", - "lidar_points", - "image_description" - ] - }, - "training_dataloader": { - "batch_size": 2, - "num_workers": 3, - "prefetch_factor": 3, - "collate_fn": { - "_class_name": "dwm.datasets.common.CollateFnIgnoring", - "keys": [ - "clip_text" - ] - }, - "persistent_workers": true - }, - "validation_dataloader": { - "batch_size": 1, - "num_workers": 1, - "prefetch_factor": 3, - "collate_fn": { - "_class_name": "dwm.datasets.common.CollateFnIgnoring", - "keys": [ - "clip_text" - ] - }, - "persistent_workers": true - }, - "preview_dataloader": { - "batch_size": 1, - "num_workers": 1, - "prefetch_factor": 1, - "shuffle": true, - "drop_last": true, - "collate_fn": { - "_class_name": "dwm.datasets.common.CollateFnIgnoring", - "keys": [ - "clip_text" - ] - }, - "persistent_workers": true - }, - "informations": { - "fid": -1, - "fvd": -1, - "total_batch_sizes": 64, - "steps": 30000 - } -} \ No newline at end of file diff --git a/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa_warmup.json b/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwao.json similarity index 77% rename from configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa_warmup.json rename to configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwao.json index 1a9a82b..1a63d4d 100644 --- a/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwa_warmup.json +++ b/configs/ctsd/multi_datasets/ctsd_35_tirda_bm_nwao.json 
@@ -1,8 +1,8 @@ { "device": "cuda", "ddp_backend": "nccl", - "train_epochs": 1, - "generator_seed": 0, + "train_epochs": 4, + "generator_seed": 1, "data_shuffle": true, "fix_training_data_order": true, "global_state": { @@ -27,6 +27,19 @@ "data/nuscenes/nuScenes-map-expansion-v1.3.zip" ] }, + "opendv_czip_fs": { + "_class_name": "dwm.fs.czip.CombinedZipFileSystem", + "fs": { + "_class_name": "dwm.fs.dirfs.DirFileSystem", + "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/data/opendv" + }, + "paths": [ + "opendv-youtube-10hz-720_0.zip", + "opendv-youtube-10hz-720_1.zip", + "opendv-youtube-10hz-720_2.zip", + "opendv-youtube-10hz-720_3.zip" + ] + }, "device_mesh": { "_class_name": "torch.distributed.device_mesh.init_device_mesh", "device_type": "cuda", @@ -38,7 +51,11 @@ }, "optimizer": { "_class_name": "torch.optim.AdamW", - "lr": 1.2e-4 + "lr": 6e-5, + "betas": [ + 0.9, + 0.975 + ] }, "pipeline": { "_class_name": "dwm.pipelines.ctsd.CrossviewTemporalSD", @@ -47,6 +64,7 @@ "cat_condition": true, "cond_with_action": false, "condition_on_all_frames": true, + "uncondition_image_color": 0.1255, "added_time_ids": "fps_camera_transforms", "camera_intrinsic_embedding_indices": [ 0, @@ -97,8 +115,7 @@ "_class_name": "get_class", "class_name": "torch.float16" } - }, - "use_orig_params": true + } }, "t5_fsdp_wrapper_settings": { "sharding_strategy": { @@ -126,19 +143,22 @@ "class_name": "torch.float16" } }, - "memory_efficient_batch": 16 + "memory_efficient_batch": 12 }, "training_config": { - "freezing_pattern": "^(transformer_blocks|time_text_embed)$", "text_prompt_condition_ratio": 0.8, "3dbox_condition_ratio": 0.8, "hdmap_condition_ratio": 0.8, - "reference_frame_count": 3, + "reference_frame_count": { + "1": 0.1, + "2": 0.3, + "3": 0.6 + }, "generation_task_ratio": 0.25, "image_generation_ratio": 0.3, "all_reference_visible_ratio": 1, - "reference_frame_scale_std": 0.02, - "reference_frame_offset_std": 0.02, + "reference_frame_scale_std": 0.01, + 
"reference_frame_offset_std": 0.01, "enable_grad_scaler": true }, "inference_config": { @@ -153,7 +173,7 @@ "autoregression_data_exception_for_take_sequence": [ "crossview_mask" ], - "evaluation_item_count": 480 + "evaluation_item_count": 640 }, "model": { "_class_name": "dwm.models.crossview_temporal_dit.DiTCrossviewTemporalConditionModel", @@ -199,7 +219,7 @@ ], "crossview_gradient_checkpointing": true, "enable_temporal": true, - "temporal_attention_type": "rowwise", + "temporal_attention_type": "pointwise", "temporal_block_layers": [ 2, 3, @@ -456,6 +476,95 @@ } ] } + }, + { + "_class_name": "dwm.datasets.opendv.MotionDataset", + "fs": { + "_class_name": "dwm.common.get_state", + "key": "opendv_czip_fs" + }, + "meta_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/nijingcheng/datasets/OpenDV-YouTube.json", + "sequence_length": 19, + "fps_stride_tuples": [ + [ + 10, + 5 + ] + ], + "split": "Train", + "mini_batch": 6, + "ignore_list": [ + "izhGt1GnGFk" + ], + "enable_pts": false, + "enable_fake_camera_transforms": true, + "enable_fake_3dbox_images": true, + "enable_fake_hdmap_images": true, + "fake_condition_image_color": [ + 32, + 32, + 32 + ], + "image_description_settings": { + "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/opendv_caption.json", + "candidates_times_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/opendv_candidates_times.json", + "seed": 5, + "reorder_keys": true, + "drop_rates": { + "environment": 0.04, + "objects": 0.08, + "image_description": 0.16 + } + }, + "stub_key_data_dict": { + "pts": [ + "content", + { + "_class_name": "torch.zeros", + "size": [ + 19, + 7 + ] + } + ], + "crossview_mask": [ + "content", + { + "_class_name": "torch.eye", + "n": 6, + "dtype": { + "_class_name": "get_class", + "class_name": "torch.bool" + } + } + ], + "ego_transforms": [ + "content", + { + "_class_name": "torch.repeat_interleave", + "input": { + "_class_name": 
"torch.repeat_interleave", + "input": { + "_class_name": "torch.reshape", + "input": { + "_class_name": "torch.eye", + "n": 4 + }, + "shape": [ + 1, + 1, + 4, + 4 + ] + }, + "repeats": 7, + "dim": 1 + }, + "repeats": 19, + "dim": 0 + } + ] + } } ] }, @@ -529,6 +638,7 @@ "pop_list": [ "images", "lidar_points", + "lidar_transforms", "image_description" ] }, @@ -703,6 +813,85 @@ } ] } + }, + { + "_class_name": "dwm.datasets.opendv.MotionDataset", + "fs": { + "_class_name": "dwm.common.get_state", + "key": "opendv_czip_fs" + }, + "meta_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/nijingcheng/datasets/OpenDV-YouTube.json", + "sequence_length": 35, + "fps_stride_tuples": [ + [ + 10, + 75 + ] + ], + "split": "Val", + "mini_batch": 6, + "enable_pts": false, + "enable_fake_camera_transforms": true, + "enable_fake_3dbox_images": true, + "enable_fake_hdmap_images": true, + "fake_condition_image_color": [ + 32, + 32, + 32 + ], + "image_description_settings": { + "path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/opendv_caption.json", + "candidates_times_path": "/cache/aoss.cn-sh-01.sensecoreapi-oss.cn/users/wuzehuan/workspaces/worldmodels/data/opendv_candidates_times.json" + }, + "stub_key_data_dict": { + "pts": [ + "content", + { + "_class_name": "torch.zeros", + "size": [ + 35, + 7 + ] + } + ], + "crossview_mask": [ + "content", + { + "_class_name": "torch.eye", + "n": 6, + "dtype": { + "_class_name": "get_class", + "class_name": "torch.bool" + } + } + ], + "ego_transforms": [ + "content", + { + "_class_name": "torch.repeat_interleave", + "input": { + "_class_name": "torch.repeat_interleave", + "input": { + "_class_name": "torch.reshape", + "input": { + "_class_name": "torch.eye", + "n": 4 + }, + "shape": [ + 1, + 1, + 4, + 4 + ] + }, + "repeats": 7, + "dim": 1 + }, + "repeats": 35, + "dim": 0 + } + ] + } } ] }, @@ -776,11 +965,43 @@ "pop_list": [ "images", "lidar_points", + "lidar_transforms", "image_description" ] }, + 
"mix_config": { + "256-448": [ + 0.6, + [ + [ + 19, + 2, + 1.0 + ] + ] + ], + "176-304": [ + 0.3, + [ + [ + 19, + 4, + 1.0 + ] + ] + ], + "144-256": [ + 0.1, + [ + [ + 19, + 6, + 1.0 + ] + ] + ] + }, "training_dataloader": { - "batch_size": 2, "num_workers": 3, "prefetch_factor": 3, "collate_fn": { @@ -818,9 +1039,12 @@ "persistent_workers": true }, "informations": { - "fid": 14.75, - "fvd": 268.69, - "total_batch_sizes": 64, - "steps": 5000 + "fid": 9.46, + "fvd": 91.55, + "fvd_on_nusc_by_3_ref_frames": 38.83, + "fvd_on_nusc_by_1_ref_frames": 45.98, + "fvd_on_nusc_without_ref_frame": 117.50, + "average_total_batch_sizes": 96, + "steps": 40000 } } \ No newline at end of file diff --git a/examples/ctsd_35_6views_video_generation_with_layout.json b/examples/ctsd_35_6views_video_generation_with_layout.json new file mode 100644 index 0000000..bb0e83d --- /dev/null +++ b/examples/ctsd_35_6views_video_generation_with_layout.json @@ -0,0 +1,292 @@ +{ + "device": "cuda", + "generator_seed": 0, + "pipeline": { + "_class_name": "dwm.pipelines.ctsd.CrossviewTemporalSD", + "common_config": { + "frame_prediction_style": "ctsd", + "condition_on_all_frames": true, + "uncondition_image_color": 0.1255, + "added_time_ids": "fps_camera_transforms", + "camera_intrinsic_embedding_indices": [ + 0, + 4, + 2, + 5 + ], + "camera_intrinsic_denom_embedding_indices": [ + 1, + 1, + 0, + 1 + ], + "camera_transform_embedding_indices": [ + 2, + 6, + 10, + 3, + 7, + 11 + ], + "autocast": { + "device_type": "cuda" + }, + "text_encoder_load_args": { + "variant": "fp16", + "torch_dtype": { + "_class_name": "get_class", + "class_name": "torch.float16" + }, + "quantization_config": { + "_class_name": "diffusers.quantizers.quantization_config.BitsAndBytesConfig", + "load_in_4bit": true, + "bnb_4bit_quant_type": "nf4", + "bnb_4bit_compute_dtype": { + "_class_name": "get_class", + "class_name": "torch.float16" + } + } + }, + "memory_efficient_batch": 12 + }, + "training_config": {}, + "inference_config": 
{ + "guidance_scale": 4, + "inference_steps": 40, + "preview_image_size": [ + 448, + 252 + ], + "sequence_length_per_iteration": 19, + "reference_frame_count": 3, + "autoregression_data_exception_for_take_sequence": [ + "crossview_mask" + ] + }, + "model": { + "_class_name": "dwm.models.crossview_temporal_dit.DiTCrossviewTemporalConditionModel", + "dual_attention_layers": [ + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12 + ], + "attention_head_dim": 64, + "caption_projection_dim": 1536, + "in_channels": 16, + "joint_attention_dim": 4096, + "num_attention_heads": 24, + "num_layers": 24, + "out_channels": 16, + "patch_size": 2, + "pooled_projection_dim": 2048, + "pos_embed_max_size": 384, + "qk_norm": "rms_norm", + "qk_norm_on_additional_modules": "rms_norm", + "sample_size": 128, + "perspective_modeling_type": "implicit", + "projection_class_embeddings_input_dim": 2816, + "enable_crossview": true, + "crossview_attention_type": "rowwise", + "crossview_block_layers": [ + 1, + 5, + 9, + 13, + 17, + 21 + ], + "crossview_gradient_checkpointing": true, + "enable_temporal": true, + "temporal_attention_type": "pointwise", + "temporal_block_layers": [ + 2, + 3, + 6, + 7, + 10, + 11, + 14, + 15, + 18, + 19, + 22, + 23 + ], + "temporal_gradient_checkpointing": true, + "mixer_type": "AlphaBlender", + "merge_factor": 2, + "condition_image_adapter_config": { + "in_channels": 6, + "channels": [ + 1536, + 1536, + 1536, + 1536, + 1536, + 1536 + ], + "is_downblocks": [ + true, + false, + false, + false, + false, + false + ], + "num_res_blocks": 2, + "downscale_factor": 8, + "use_zero_convs": true + } + }, + "model_dtype": { + "_class_name": "get_class", + "class_name": "torch.float16" + }, + "pretrained_model_name_or_path": "/mnt/afs/user/wuzehuan/Downloads/models/stable-diffusion-3.5-medium", + "model_checkpoint_path": "/mnt/afs/user/wuzehuan/Tasks/ctsd_35_tirda_bm_nwao/checkpoints/40000.pth" + }, + "validation_dataset": { + "_class_name": 
"dwm.datasets.common.DatasetAdapter", + "base_dataset": { + "_class_name": "dwm.datasets.preview.PreviewDataset", + "json_file": "/mnt/afs/user/wuzehuan/Documents/DWM/output/carla_town04_package/data.json", + "sequence_length": 179, + "fps_stride_tuples": [ + [ + 10, + 60 + ] + ], + "sensor_channels": [ + "CAM_FRONT_LEFT", + "CAM_FRONT", + "CAM_FRONT_RIGHT", + "CAM_BACK_RIGHT", + "CAM_BACK", + "CAM_BACK_LEFT" + ], + "enable_camera_transforms": true, + "use_hdmap": true, + "use_3dbox": true, + "drop_vehicle_color": true, + "stub_key_data_dict": { + "crossview_mask": [ + "content", + { + "_class_name": "torch.tensor", + "data": { + "_class_name": "json.loads", + "s": "[[1,1,0,0,0,1],[1,1,1,0,0,0],[0,1,1,1,0,0],[0,0,1,1,1,0],[0,0,0,1,1,1],[1,0,0,0,1,1]]" + }, + "dtype": { + "_class_name": "get_class", + "class_name": "torch.bool" + } + } + ] + } + }, + "transform_list": [ + { + "old_key": "images", + "new_key": "vae_images", + "transform": { + "_class_name": "torchvision.transforms.Compose", + "transforms": [ + { + "_class_name": "torchvision.transforms.Resize", + "size": [ + 256, + 448 + ] + }, + { + "_class_name": "torchvision.transforms.ToTensor" + } + ] + } + }, + { + "old_key": "3dbox_images", + "new_key": "3dbox_images", + "transform": { + "_class_name": "torchvision.transforms.Compose", + "transforms": [ + { + "_class_name": "torchvision.transforms.Resize", + "size": [ + 256, + 448 + ] + }, + { + "_class_name": "torchvision.transforms.ToTensor" + } + ] + } + }, + { + "old_key": "hdmap_images", + "new_key": "hdmap_images", + "transform": { + "_class_name": "torchvision.transforms.Compose", + "transforms": [ + { + "_class_name": "torchvision.transforms.Resize", + "size": [ + 256, + 448 + ] + }, + { + "_class_name": "torchvision.transforms.ToTensor" + } + ] + } + }, + { + "old_key": "image_description", + "new_key": "clip_text", + "transform": { + "_class_name": "dwm.datasets.common.Copy" + }, + "stack": false + } + ], + "pop_list": [ + "images", + "pred_images", 
+ "image_description" + ] + }, + "validation_dataloader": { + "batch_size": 1, + "num_workers": 1, + "collate_fn": { + "_class_name": "dwm.datasets.common.CollateFnIgnoring", + "keys": [ + "clip_text" + ] + } + }, + "preview_dataloader": { + "batch_size": 1, + "num_workers": 1, + "collate_fn": { + "_class_name": "dwm.datasets.common.CollateFnIgnoring", + "keys": [ + "clip_text" + ] + } + } +} \ No newline at end of file diff --git a/src/dwm/datasets/common.py b/src/dwm/datasets/common.py index b3c6872..8287dc6 100644 --- a/src/dwm/datasets/common.py +++ b/src/dwm/datasets/common.py @@ -51,9 +51,6 @@ def apply_transform(transform, a, stack: bool = True): else: return transform(a) - def apply_temporal_transform(transform, a): - return transform(a) - def __init__( self, base_dataset: torch.utils.data.Dataset, transform_list: list, pop_list=None @@ -70,9 +67,8 @@ def __getitem__(self, index): if isinstance(index, int): item = self.base_dataset[index] for i in self.transform_list: - if getattr(i["transform"], 'is_temporal_transform', False): - item[i["new_key"]] = DatasetAdapter.apply_temporal_transform( - i["transform"], item[i["old_key"]]) + if i.get("is_dynamic_transform", False): + item = i["transform"](item) else: item[i["new_key"]] = DatasetAdapter.apply_transform( i["transform"], item[i["old_key"]], @@ -313,7 +309,6 @@ def make_image_description_string(caption_dict: dict, settings: dict): probabilities of the corresponding key elements in the caption_dict being dropped. 
""" - default_image_description_keys = [ "time", "weather", "environment", "objects", "image_description" ] diff --git a/src/dwm/datasets/nuscenes.py b/src/dwm/datasets/nuscenes.py index 5ba506d..df32511 100755 --- a/src/dwm/datasets/nuscenes.py +++ b/src/dwm/datasets/nuscenes.py @@ -732,7 +732,7 @@ def get_hdmap_bev_image( MotionDataset.default_bev_from_ego_transform) color_table = hdmap_bev_settings.get( "color_table", MotionDataset.default_hdmap_color_table) - + fill_map = hdmap_bev_settings.get("fill_map", True) # get the transform from the world (map) space to the BEV space world_from_ego = dwm.datasets.common.get_transform( sample_data["rotation"], sample_data["translation"]) @@ -759,14 +759,14 @@ def get_hdmap_bev_image( for polygon_token in i["polygon_tokens"]: MotionDataset.draw_polygon_to_bev_image( polygons[polygon_token], nodes, draw, bev_from_world, - (0, 0, 255), pen_width, solid=True) + (0, 0, 255), pen_width, solid=fill_map) if "ped_crossing" in color_table and "ped_crossing" in map: pen_color = tuple(color_table["ped_crossing"]) for i in map["ped_crossing"]: MotionDataset.draw_polygon_to_bev_image( polygons[i["polygon_token"]], nodes, draw, bev_from_world, - (255, 0, 0), pen_width, solid=True) + (255, 0, 0), pen_width, solid=fill_map) if "lane" in color_table and "lane" in map: pen_color = tuple(color_table["lane"]) diff --git a/src/dwm/fs/s3fs.py b/src/dwm/fs/s3fs.py index 0ab8aec..74afd5b 100644 --- a/src/dwm/fs/s3fs.py +++ b/src/dwm/fs/s3fs.py @@ -138,7 +138,7 @@ def _open( def ls(self, path, detail=True, **kwargs): self.reinit_if_forked() bucket, key = S3File.find_bucket_key(path) - if not key.endswith("/"): + if len(key) > 0 and not key.endswith("/"): key = key + "/" # NOTE: only files are listed @@ -151,7 +151,7 @@ def ls(self, path, detail=True, **kwargs): else: response = self.client.list_objects( Bucket=bucket, Delimiter="/", Prefix=key, - ContinuationToken=continuation_token) + Marker=continuation_token) if "Contents" in response: for i 
in response["Contents"]: @@ -165,7 +165,7 @@ def ls(self, path, detail=True, **kwargs): }) if response["IsTruncated"]: - continuation_token = response["NextContinuationToken"] + continuation_token = response["NextMarker"] else: break diff --git a/src/dwm/models/crossview_temporal_dit.py b/src/dwm/models/crossview_temporal_dit.py index d0d1e50..06e223f 100644 --- a/src/dwm/models/crossview_temporal_dit.py +++ b/src/dwm/models/crossview_temporal_dit.py @@ -350,13 +350,13 @@ def forward_temporal_block_and_mix_result( b=batch_size, v=view_count, w=width) else: temporal_hidden_states = einops.rearrange( - temporal_hidden_states, "(b t v) hw c -> (bv hw) t c", + temporal_hidden_states, "(b t v) hw c -> (b v hw) t c", b=batch_size, t=sequence_length) temporal_hidden_states = temporal_block( temporal_hidden_states) temporal_hidden_states = einops.rearrange( - temporal_hidden_states, "(bv hw) t c -> (b t v) hw c", - b=batch_size, t=sequence_length) + temporal_hidden_states, "(b v hw) t c -> (b t v) hw c", + b=batch_size, v=view_count, t=sequence_length) return mixer( hidden_states.view(