diff --git a/graphium/nn/architectures/global_architectures.py b/graphium/nn/architectures/global_architectures.py index 2b2ecc1c2..1b6f44dd9 100644 --- a/graphium/nn/architectures/global_architectures.py +++ b/graphium/nn/architectures/global_architectures.py @@ -666,8 +666,9 @@ def _parse_subset_in_dim( """ # Parse the subset_in_dim, make sure value is between 0 and 1 + subset_idx = None if subset_in_dim is None: - subset_in_dim = 1.0 + return 1.0, None if isinstance(subset_in_dim, int): assert ( subset_in_dim > 0 and subset_in_dim <= in_dim @@ -681,9 +682,7 @@ def _parse_subset_in_dim( subset_in_dim = 1 # Create the subset_idx, which is a list of indices to use for each ensemble - if subset_in_dim == in_dim: - subset_idx = None - else: + if subset_in_dim != in_dim: subset_idx = torch.stack([torch.randperm(in_dim)[:subset_in_dim] for _ in range(num_ensemble)]) return subset_in_dim, subset_idx @@ -719,7 +718,9 @@ def forward(self, h: torch.Tensor) -> torch.Tensor: # Subset the input features for each MLP in the ensemble if self.subset_idx is not None: if len(h.shape) != 2: - assert h.shape[-3] == 1, f"Expected shape to be [B, Din] or [..., 1, B, Din], got {h.shape}." + assert ( + h.shape[-3] == 1 + ), f"Expected shape to be [B, Din] or [..., 1, B, Din] when using `subset_in_dim`, got {h.shape}." h = h[..., self.subset_idx].transpose(-2, -3) # Run the standard forward pass