From 172b6665f0f9e625c5dead7c98e06b0bd7e99e7c Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Sun, 16 Feb 2025 15:55:16 -0800 Subject: [PATCH] Even more improvements to MSA subsampling (#306) --- chai_lab/chai1.py | 22 +++++++++-- chai_lab/data/dataset/msas/utils.py | 58 +++++++++++++++++++++++++---- 2 files changed, 69 insertions(+), 11 deletions(-) diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index 7373942..305d0a3 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -35,7 +35,9 @@ from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas from chai_lab.data.dataset.msas.load import get_msa_contexts from chai_lab.data.dataset.msas.msa_context import MSAContext -from chai_lab.data.dataset.msas.utils import subsample_msa_rows +from chai_lab.data.dataset.msas.utils import ( + subsample_and_reorder_msa_feats_n_mask, +) from chai_lab.data.dataset.structure.all_atom_structure_context import ( AllAtomStructureContext, ) @@ -705,14 +707,28 @@ def run_folding_on_context( token_single_trunk_repr = token_single_initial_repr token_pair_trunk_repr = token_pair_initial_repr for _ in tqdm(range(num_trunk_recycles), desc="Trunk recycles"): + subsampled_msa_input_feats, subsampled_msa_mask = None, None + if recycle_msa_subsample > 0: + subsampled_msa_input_feats, subsampled_msa_mask = ( + subsample_and_reorder_msa_feats_n_mask( + msa_input_feats, + msa_mask, + ) + ) (token_single_trunk_repr, token_pair_trunk_repr) = trunk.forward( move_to_device=device, token_single_trunk_initial_repr=token_single_initial_repr, token_pair_trunk_initial_repr=token_pair_initial_repr, token_single_trunk_repr=token_single_trunk_repr, # recycled token_pair_trunk_repr=token_pair_trunk_repr, # recycled - msa_input_feats=msa_input_feats, - msa_mask=subsample_msa_rows(msa_mask, select_n_rows=recycle_msa_subsample), + msa_input_feats=( + subsampled_msa_input_feats + if subsampled_msa_input_feats is not None + else msa_input_feats + ), + msa_mask=( + subsampled_msa_mask if subsampled_msa_mask is not None else msa_mask + ), template_input_feats=template_input_feats, template_input_masks=template_input_masks, token_single_mask=token_single_mask, diff --git a/chai_lab/data/dataset/msas/utils.py b/chai_lab/data/dataset/msas/utils.py index b61d023..3f1c910 100644 --- a/chai_lab/data/dataset/msas/utils.py +++ b/chai_lab/data/dataset/msas/utils.py @@ -2,22 +2,25 @@ # Licensed under the Apache License, Version 2.0. # See the LICENSE file for details. +import logging + import torch -from einops import rearrange, reduce, repeat +from einops import rearrange, reduce from torch import Tensor +from torch.nn import functional as F -from chai_lab.utils.typing import Bool, typecheck +from chai_lab.utils.typing import Bool, Float, typecheck @typecheck -def subsample_msa_rows( +def _subsample_msa_rows( mask: Bool[Tensor, "1 depth tokens"], select_n_rows: int = 4096, generator: torch.Generator | None = None, -) -> Bool[Tensor, "1 depth tokens"]: +) -> Bool[Tensor, "depth"] | None: """Adjust masking to look at a random subset of msas. - Returns input mask as-is if select_n_rows <= 0 or depth < select_n_rows.""" + Returns None if select_n_rows <= 0 or depth < select_n_rows.""" # Count the number of non-padding residues in each row of the MSA msa_sizes = rearrange( reduce(mask, "b depth tok -> b depth", reduction="sum"), "1 depth -> depth" @@ -25,7 +28,7 @@ def subsample_msa_rows( nonnull_rows_mask = msa_sizes > 0 input_depth = nonnull_rows_mask.sum().item() if select_n_rows <= 0 or input_depth <= select_n_rows: - return mask + return None # Bias towards bigger hit MSAs; 0 size is automatically nulled out mask_ranking = msa_sizes * torch.rand( @@ -36,9 +39,48 @@ def subsample_msa_rows( ) # Ascending sort -> choose the last (highest scoring) rows selected_row_indices = mask_ranking.argsort()[-select_n_rows:] + # We should never sample empty MSA rows + assert not (~nonnull_rows_mask[selected_row_indices]).any() - # Create a mask for selected row indices selection_mask = torch.zeros_like(nonnull_rows_mask) selection_mask[selected_row_indices] = True + return selection_mask + - return mask & repeat(selection_mask, "d -> 1 d 1") +@typecheck +def subsample_and_reorder_msa_feats_n_mask( + feats: Float[Tensor, "1 depth tokens dim"], + mask: Bool[Tensor, "1 depth tokens"], + select_n_rows: int = 4096, + generator: torch.Generator | None = None, +) -> tuple[Float[Tensor, "1 depth tokens dim"], Bool[Tensor, "1 depth tokens"]]: + selection_mask = _subsample_msa_rows( + mask=mask, + select_n_rows=select_n_rows, + generator=generator, + ) + if selection_mask is None: # No subsampling + return feats, mask + + # Select the rows; where returns in order from top to bottom, preserving order + (selection_idx,) = torch.where(selection_mask) + logging.info(f"Subsampling {selection_idx.tolist()[:5]}...") + (unselected_idx,) = torch.where(~selection_mask) + combo_idx = torch.cat([selection_idx, unselected_idx]) + # Features are reordered, while mask is selected + padded + feats_sampled = torch.index_select(feats, dim=1, index=combo_idx) + mask_sampled = torch.index_select(mask, dim=1, index=selection_idx) + # Every sampled row should have nonzero coverage + assert mask_sampled.any(dim=-1).all() + + # Pad with zeros + _, orig_depth, _ = mask.shape + _, new_depth, _ = mask_sampled.shape + assert (n_pad := orig_depth - new_depth) > 0 + # Padding is last dim, moving forward, e.g., for last two dimensions, it is: + # (left, right, top, bottom) + # [0, 0, 0, n_pad] ignores the token dim and pads out the depth dim + return ( + feats_sampled, + F.pad(mask_sampled, pad=[0, 0, 0, n_pad], value=False), + )