diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index e09c1c8..7373942 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -402,9 +402,7 @@ def make_all_atom_feature_context( # Load templates if templates_path is None: - assert ( - not use_templates_server - ), "Templates path should never be none when querying server for templates" + assert not use_templates_server, "Server should have written a path" template_context = TemplateContext.empty( n_tokens=n_actual_tokens, n_templates=MAX_NUM_TEMPLATES, diff --git a/chai_lab/data/dataset/msas/msa_context.py b/chai_lab/data/dataset/msas/msa_context.py index ce94c43..e17392b 100644 --- a/chai_lab/data/dataset/msas/msa_context.py +++ b/chai_lab/data/dataset/msas/msa_context.py @@ -52,7 +52,7 @@ def __getitem__(self, subscript: tuple) -> "MSAContext": mask=self.mask[subscript], ) - def take_rows_with_padding(self, row_indices_with_nones: list): + def take_rows_with_padding(self, row_indices_with_nones: list[int | None]): """ allows specifying index=None, which will be filled with empty sequence, helpful to align multiple sequences diff --git a/chai_lab/data/dataset/msas/utils.py b/chai_lab/data/dataset/msas/utils.py index fe05468..b61d023 100644 --- a/chai_lab/data/dataset/msas/utils.py +++ b/chai_lab/data/dataset/msas/utils.py @@ -3,7 +3,7 @@ # See the LICENSE file for details. import torch -from einops import rearrange, repeat +from einops import rearrange, reduce, repeat from torch import Tensor from chai_lab.utils.typing import Bool, typecheck @@ -18,20 +18,27 @@ def subsample_msa_rows( """Adjust masking to look at a random subset of msas. Returns input mask as-is if select_n_rows <= 0 or depth < select_n_rows.""" - nonnull_rows_mask = rearrange(mask.any(dim=-1), "1 d -> d") + # Count the number of non-padding residues in each row of the MSA + msa_sizes = rearrange( + reduce(mask, "b depth tok -> b depth", reduction="sum"), "1 depth -> depth" + ) + nonnull_rows_mask = msa_sizes > 0 input_depth = nonnull_rows_mask.sum().item() if select_n_rows <= 0 or input_depth <= select_n_rows: return mask - # Select from rows of the MSA that are not fully masked out - (nonnull_row_indices,) = torch.where(nonnull_rows_mask) - assert (n := nonnull_row_indices.numel()) > select_n_rows - permuted = torch.randperm(n, device=mask.device, generator=generator) - selected_row_indices = nonnull_row_indices[permuted[:select_n_rows]] + # Bias towards bigger hit MSAs; 0 size is automatically nulled out + mask_ranking = msa_sizes * torch.rand( + size=msa_sizes.shape, + dtype=torch.float16, + device=msa_sizes.device, + generator=generator, + ) + # Ascending sort -> choose the last (highest scoring) rows + selected_row_indices = mask_ranking.argsort()[-select_n_rows:] # Create a mask for selected row indices selection_mask = torch.zeros_like(nonnull_rows_mask) selection_mask[selected_row_indices] = True - selection_mask = repeat(selection_mask, "d -> 1 d 1") - return mask & selection_mask + return mask & repeat(selection_mask, "d -> 1 d 1")