-
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0d09830
commit 09b5ff9
Showing
2 changed files
with
309 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
# dynamic_encoder.py | ||
|
||
""" Dynamic fusion encoder implementation for multimodal learning """ | ||
|
||
|
||
from typing import Dict, Optional, List, Union | ||
import torch | ||
from torch import nn | ||
|
||
from pvnet.models.multimodal.encoders.basic_blocks import AbstractNWPSatelliteEncoder | ||
from pvnet.models.multimodal.fusion_blocks import DynamicFusionModule, ModalityGating | ||
from pvnet.models.multimodal.attention_blocks import CrossModalAttention, SelfAttention | ||
from pvnet.models.multimodal.encoders.encoders3d import DefaultPVNet2 | ||
|
||
|
||
class PVEncoder(nn.Module):
    """Simplified PV encoder - maintains sequence dimension.

    Applies the same site-level projection independently at every timestep,
    mapping [batch_size, sequence_length, num_sites] to
    [batch_size, sequence_length, out_features].
    """

    def __init__(self, sequence_length: int, num_sites: int, out_features: int):
        """
        Args:
            sequence_length: Number of timesteps in the input sequence.
            num_sites: Number of PV sites (feature dimension per timestep).
            out_features: Output feature dimension per timestep.
        """
        super().__init__()
        self.sequence_length = sequence_length
        self.num_sites = num_sites
        self.out_features = out_features

        # Process each timestep independently. Every layer here acts on the
        # last tensor dimension, so one parameter set is shared across all
        # timesteps.
        self.encoder = nn.Sequential(
            nn.Linear(num_sites, out_features),
            nn.LayerNorm(out_features),
            nn.ReLU(),
            nn.Dropout(0.1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of PV sequences.

        Args:
            x: Tensor of shape [batch_size, sequence_length, num_sites].

        Returns:
            Tensor of shape [batch_size, sequence_length, out_features].
        """
        # nn.Linear, LayerNorm, ReLU and Dropout all operate on the last
        # dimension, so the whole sequence can be encoded in a single
        # vectorised call instead of a Python loop over timesteps followed
        # by torch.stack — identical result, far less overhead.
        return self.encoder(x)
|
||
|
||
class DynamicFusionEncoder(AbstractNWPSatelliteEncoder):

    """Encoder that implements dynamic fusion of satellite/NWP data streams"""

    def __init__(
        self,
        sequence_length: int,
        image_size_pixels: int,
        modality_channels: Dict[str, int],
        out_features: int,
        modality_encoders: Dict[str, dict],
        cross_attention: Dict,
        modality_gating: Dict,
        dynamic_fusion: Dict,
        hidden_dim: int = 256,
        fc_features: int = 128,
        num_heads: int = 8,
        dropout: float = 0.1,
        use_gating: bool = True,
        use_cross_attention: bool = True
    ):
        """Dynamic fusion encoder for multimodal satellite/NWP data.

        Args:
            sequence_length: Number of timesteps in each input stream.
            image_size_pixels: Spatial size of image-like (sat/NWP) inputs.
            modality_channels: Mapping of modality name -> channel count;
                their sum is passed to the base class as ``in_channels``.
            out_features: Final output feature dimension.
            modality_encoders: Per-modality encoder config dicts. Keys
                containing 'nwp' or 'sat' get a ``DefaultPVNet2`` encoder;
                a 'pv' key gets a ``PVEncoder``.
            cross_attention: Kwargs for ``CrossModalAttention`` (its
                ``embed_dim`` is overwritten with ``hidden_dim``).
            modality_gating: Kwargs for ``ModalityGating`` (its
                ``feature_dims`` are overwritten with ``hidden_dim``).
            dynamic_fusion: Kwargs for ``DynamicFusionModule``
                (``feature_dims`` and ``hidden_dim`` are overwritten).
            hidden_dim: Internal per-modality feature dimension.
            fc_features: Width of the intermediate layer of the output head.
            num_heads: NOTE(review) — accepted but never used; attention
                head counts come from the config dicts instead.
            dropout: Dropout probability for the feature projections.
            use_gating: Enable the optional ``ModalityGating`` stage.
            use_cross_attention: Enable the optional ``CrossModalAttention``
                stage (only applied when >1 modality is present).
        """
        super().__init__(
            sequence_length=sequence_length,
            image_size_pixels=image_size_pixels,
            in_channels=sum(modality_channels.values()),
            out_features=out_features
        )

        self.modalities = list(modality_channels.keys())
        self.hidden_dim = hidden_dim
        self.sequence_length = sequence_length

        # Initialize modality-specific encoders
        self.modality_encoders = nn.ModuleDict()
        for modality, config in modality_encoders.items():
            # Copy so the pops/overwrites below never mutate caller config.
            config = config.copy()
            if 'nwp' in modality or 'sat' in modality:
                # Image-like stream: 3D-conv encoder producing a flat vector.
                encoder = DefaultPVNet2(
                    sequence_length=sequence_length,
                    image_size_pixels=config.get('image_size_pixels', image_size_pixels),
                    in_channels=modality_channels[modality],
                    out_features=config.get('out_features', hidden_dim),
                    number_of_conv3d_layers=config.get('number_of_conv3d_layers', 4),
                    conv3d_channels=config.get('conv3d_channels', 32),
                    batch_norm=config.get('batch_norm', True),
                    fc_dropout=config.get('fc_dropout', 0.2)
                )

                # Reshape the flat encoder output back into a sequence.
                # NOTE(review): this yields per-timestep features of size
                # hidden_dim // sequence_length, whereas the PV branch below
                # outputs hidden_dim per timestep — confirm downstream
                # modules expect these differing widths. Requires
                # hidden_dim to be divisible by sequence_length.
                self.modality_encoders[modality] = nn.Sequential(
                    encoder,
                    nn.Unflatten(1, (sequence_length, hidden_dim//sequence_length))
                )

            elif modality == 'pv':
                # Site-level PV stream: per-timestep linear encoder.
                self.modality_encoders[modality] = PVEncoder(
                    sequence_length=sequence_length,
                    num_sites=config['num_sites'],
                    out_features=hidden_dim
                )

        # Feature projections
        # NOTE(review): these projections are constructed but never invoked
        # in forward() — either wire them in after encoding or remove them.
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout)
            )
            for modality in modality_channels.keys()
        })

        # Optional modality gating
        self.use_gating = use_gating
        if use_gating:
            # Overwrite feature_dims so gating always matches hidden_dim.
            gating_config = modality_gating.copy()
            gating_config['feature_dims'] = {
                mod: hidden_dim for mod in modality_channels.keys()
            }
            self.gating = ModalityGating(**gating_config)

        # Optional cross-modal attention
        self.use_cross_attention = use_cross_attention
        if use_cross_attention:
            attention_config = cross_attention.copy()
            attention_config['embed_dim'] = hidden_dim
            self.cross_attention = CrossModalAttention(**attention_config)

        # Dynamic fusion module
        fusion_config = dynamic_fusion.copy()
        fusion_config['feature_dims'] = {
            mod: hidden_dim for mod in modality_channels.keys()
        }
        fusion_config['hidden_dim'] = hidden_dim
        self.fusion_module = DynamicFusionModule(**fusion_config)

        # Final output projection: flattens the fused sequence and maps
        # [batch, hidden_dim * sequence_length] -> [batch, out_features].
        self.final_block = nn.Sequential(
            nn.Linear(hidden_dim * sequence_length, fc_features),
            nn.ELU(),
            nn.Linear(fc_features, out_features),
            nn.ELU(),
        )

    def forward(
        self,
        inputs: Dict[str, torch.Tensor],
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Forward pass of the dynamic fusion encoder.

        Args:
            inputs: Mapping of modality name -> input tensor. Modalities
                without a registered encoder are silently skipped.
            mask: Optional attention mask forwarded to the cross-attention
                and fusion modules — semantics defined by those modules.

        Returns:
            Tensor of shape [batch_size, out_features].

        Raises:
            ValueError: If no input modality matches a registered encoder.
        """
        # Initial encoding of each modality
        encoded_features = {}
        for modality, x in inputs.items():
            if modality not in self.modality_encoders:
                continue

            # Apply modality-specific encoder
            # Output shape: [batch_size, sequence_length, hidden_dim]
            # NOTE(review): true for the PV branch; the sat/NWP branch
            # produces hidden_dim // sequence_length per timestep (see
            # the Unflatten in __init__) — confirm intended.
            encoded_features[modality] = self.modality_encoders[modality](x)

        if not encoded_features:
            raise ValueError("No valid features found in inputs")

        # Apply modality gating if enabled
        if self.use_gating:
            encoded_features = self.gating(encoded_features)

        # Apply cross-modal attention if enabled and more than one modality
        if self.use_cross_attention and len(encoded_features) > 1:
            encoded_features = self.cross_attention(encoded_features, mask)

        # Apply dynamic fusion
        fused_features = self.fusion_module(encoded_features, mask)  # [batch, sequence, hidden]

        # Reshape and apply final projection
        batch_size = fused_features.size(0)
        fused_features = fused_features.reshape(batch_size, -1)  # Flatten sequence dimension
        output = self.final_block(fused_features)

        return output
|
||
|
||
class DynamicResidualEncoder(DynamicFusionEncoder):
    """Dynamic fusion encoder with residual connections.

    Identical to ``DynamicFusionEncoder`` except that the per-modality
    feature projections are replaced with a deeper two-layer variant
    intended for residual use.
    """

    def __init__(self, *args, **kwargs):
        """Accepts the same arguments as ``DynamicFusionEncoder``."""
        super().__init__(*args, **kwargs)

        # Dropout may arrive positionally or by keyword; fall back to the
        # parent's default when it is not in kwargs.
        dropout = kwargs.get('dropout', 0.1)

        # Override feature projections to include residual connections.
        # Iterate self.modalities (set by the parent constructor) instead of
        # kwargs['modality_channels'], which raised KeyError whenever the
        # parent's arguments were passed positionally via *args.
        self.feature_projections = nn.ModuleDict({
            modality: nn.Sequential(
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.LayerNorm(self.hidden_dim),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(self.hidden_dim, self.hidden_dim),
                nn.LayerNorm(self.hidden_dim)
            )
            for modality in self.modalities
        })
106 changes: 106 additions & 0 deletions
106
tests/models/multimodal/encoders/test_dynamic_encoder.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import pytest | ||
import torch | ||
from typing import Dict | ||
|
||
from pvnet.models.multimodal.encoders.dynamic_encoder import DynamicFusionEncoder | ||
|
||
@pytest.fixture
def minimal_config():
    """Minimal configuration for testing basic functionality"""
    sequence_length = 12
    # 60 divides evenly by sequence_length (60 / 12 = 5)
    hidden_dim = 60
    # Important: feature_dim needs to match between modalities
    feature_dim = hidden_dim // sequence_length  # 5

    sat_encoder_cfg = {
        'image_size_pixels': 24,
        'out_features': feature_dim * sequence_length,  # 60
        'number_of_conv3d_layers': 2,
        'conv3d_channels': 16,
        'batch_norm': True,
        'fc_dropout': 0.1,
    }
    pv_encoder_cfg = {
        'num_sites': 10,
        'out_features': feature_dim,  # 5 - this ensures proper dimension
    }

    # Both gating and fusion expect one entry per modality at hidden_dim.
    per_modality_dims = {'sat': hidden_dim, 'pv': hidden_dim}

    return {
        'sequence_length': sequence_length,
        'image_size_pixels': 24,
        'modality_channels': {'sat': 2, 'pv': 10},
        'out_features': 32,
        'hidden_dim': hidden_dim,
        'fc_features': 32,
        'modality_encoders': {
            'sat': sat_encoder_cfg,
            'pv': pv_encoder_cfg,
        },
        'cross_attention': {
            'embed_dim': hidden_dim,
            'num_heads': 4,
            'dropout': 0.1,
            'num_modalities': 2,
        },
        'modality_gating': {
            'feature_dims': dict(per_modality_dims),
            'hidden_dim': hidden_dim,
            'dropout': 0.1,
        },
        'dynamic_fusion': {
            'feature_dims': dict(per_modality_dims),
            'hidden_dim': hidden_dim,
            'num_heads': 4,
            'dropout': 0.1,
            'fusion_method': 'weighted_sum',
            'use_residual': True,
        },
    }
|
||
@pytest.fixture
def minimal_inputs(minimal_config):
    """Generate minimal test inputs"""
    batch_size = 2
    seq_len = minimal_config['sequence_length']

    # sat: [batch, channels, seq, H, W]; pv: [batch, seq, num_sites]
    sat = torch.randn(batch_size, 2, seq_len, 24, 24)
    pv = torch.randn(batch_size, seq_len, 10)
    return {'sat': sat, 'pv': pv}
|
||
@pytest.mark.parametrize("batch_size", [1, 2, 4])
def test_batch_sizes(minimal_config, minimal_inputs, batch_size):
    """Test different batch sizes.

    Fixes two defects in the original: a stray ``self`` parameter (this is
    a module-level test, not a method, so pytest treated ``self`` as a
    missing fixture), and a ``batch_size`` parameter that was never
    supplied because there was no ``pytest.mark.parametrize`` decorator.
    """
    encoder = DynamicFusionEncoder(
        sequence_length=minimal_config['sequence_length'],
        image_size_pixels=minimal_config['image_size_pixels'],
        modality_channels=minimal_config['modality_channels'],
        out_features=minimal_config['out_features'],
        modality_encoders=minimal_config['modality_encoders'],
        cross_attention=minimal_config['cross_attention'],
        modality_gating=minimal_config['modality_gating'],
        dynamic_fusion=minimal_config['dynamic_fusion'],
        hidden_dim=minimal_config['hidden_dim'],
        fc_features=minimal_config['fc_features']
    )

    # Adjust input batch sizes: slice down, or tile up by a whole multiple
    # of the fixture's batch size.
    adjusted_inputs = {}
    for k, v in minimal_inputs.items():
        if batch_size < v.size(0):
            adjusted_inputs[k] = v[:batch_size]
        else:
            repeat_factor = batch_size // v.size(0)
            adjusted_inputs[k] = v.repeat(repeat_factor, *[1] * (len(v.shape) - 1))

    with torch.no_grad():
        output = encoder(adjusted_inputs)

    assert output.shape == (batch_size, minimal_config['out_features'])
    assert not torch.isnan(output).any()