Skip to content

Commit

Permalink
Forgot to push this change
Browse files Browse the repository at this point in the history
  • Loading branch information
OpheliaMiralles committed Oct 18, 2024
1 parent 041d64e commit 6d4e127
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions src/anemoi/training/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,8 +239,12 @@ def __iter__(self) -> torch.Tensor:
x = self.data[start : end : self.timeincrement]
if self.spatial_index is not None:
x = x[..., self.spatial_index]
x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
self.ensemble_dim = 1
if self.label == "predict":
x = rearrange(x, "dates variables ensemble gridpoints -> ensemble dates gridpoints variables")
self.ensemble_dim = 0
else:
x = rearrange(x, "dates variables ensemble gridpoints -> dates ensemble gridpoints variables")
self.ensemble_dim = 1

yield torch.from_numpy(x)

Expand Down

0 comments on commit 6d4e127

Please sign in to comment.