Skip to content

Commit

Permalink
Even more improvements to MSA subsampling (#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
wukevin authored Feb 16, 2025
1 parent 169dd4d commit 172b666
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 11 deletions.
22 changes: 19 additions & 3 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas
from chai_lab.data.dataset.msas.load import get_msa_contexts
from chai_lab.data.dataset.msas.msa_context import MSAContext
from chai_lab.data.dataset.msas.utils import subsample_msa_rows
from chai_lab.data.dataset.msas.utils import (
subsample_and_reorder_msa_feats_n_mask,
)
from chai_lab.data.dataset.structure.all_atom_structure_context import (
AllAtomStructureContext,
)
Expand Down Expand Up @@ -705,14 +707,28 @@ def run_folding_on_context(
token_single_trunk_repr = token_single_initial_repr
token_pair_trunk_repr = token_pair_initial_repr
for _ in tqdm(range(num_trunk_recycles), desc="Trunk recycles"):
subsampled_msa_input_feats, subsampled_msa_mask = None, None
if recycle_msa_subsample > 0:
subsampled_msa_input_feats, subsampled_msa_mask = (
subsample_and_reorder_msa_feats_n_mask(
msa_input_feats,
msa_mask,
)
)
(token_single_trunk_repr, token_pair_trunk_repr) = trunk.forward(
move_to_device=device,
token_single_trunk_initial_repr=token_single_initial_repr,
token_pair_trunk_initial_repr=token_pair_initial_repr,
token_single_trunk_repr=token_single_trunk_repr, # recycled
token_pair_trunk_repr=token_pair_trunk_repr, # recycled
msa_input_feats=msa_input_feats,
msa_mask=subsample_msa_rows(msa_mask, select_n_rows=recycle_msa_subsample),
msa_input_feats=(
subsampled_msa_input_feats
if subsampled_msa_input_feats is not None
else msa_input_feats
),
msa_mask=(
subsampled_msa_mask if subsampled_msa_mask is not None else msa_mask
),
template_input_feats=template_input_feats,
template_input_masks=template_input_masks,
token_single_mask=token_single_mask,
Expand Down
58 changes: 50 additions & 8 deletions chai_lab/data/dataset/msas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,33 @@
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for details.

import logging

import torch
from einops import rearrange, reduce, repeat
from einops import rearrange, reduce
from torch import Tensor
from torch.nn import functional as F

from chai_lab.utils.typing import Bool, typecheck
from chai_lab.utils.typing import Bool, Float, typecheck


@typecheck
def subsample_msa_rows(
def _subsample_msa_rows(
mask: Bool[Tensor, "1 depth tokens"],
select_n_rows: int = 4096,
generator: torch.Generator | None = None,
) -> Bool[Tensor, "1 depth tokens"]:
) -> Bool[Tensor, "depth"] | None:
"""Adjust masking to look at a random subset of msas.
Returns input mask as-is if select_n_rows <= 0 or depth < select_n_rows."""
Returns None if select_n_rows <= 0 or depth < select_n_rows."""
# Count the number of non-padding residues in each row of the MSA
msa_sizes = rearrange(
reduce(mask, "b depth tok -> b depth", reduction="sum"), "1 depth -> depth"
)
nonnull_rows_mask = msa_sizes > 0
input_depth = nonnull_rows_mask.sum().item()
if select_n_rows <= 0 or input_depth <= select_n_rows:
return mask
return None

# Bias towards bigger hit MSAs; 0 size is automatically nulled out
mask_ranking = msa_sizes * torch.rand(
Expand All @@ -36,9 +39,48 @@ def subsample_msa_rows(
)
# Ascending sort -> choose the last (highest scoring) rows
selected_row_indices = mask_ranking.argsort()[-select_n_rows:]
# We should never sample empty MSA rows
assert not (~nonnull_rows_mask[selected_row_indices]).any()

# Create a mask for selected row indices
selection_mask = torch.zeros_like(nonnull_rows_mask)
selection_mask[selected_row_indices] = True
return selection_mask


return mask & repeat(selection_mask, "d -> 1 d 1")
@typecheck
def subsample_and_reorder_msa_feats_n_mask(
    feats: Float[Tensor, "1 depth tokens dim"],
    mask: Bool[Tensor, "1 depth tokens"],
    select_n_rows: int = 4096,
    generator: torch.Generator | None = None,
) -> tuple[Float[Tensor, "1 depth tokens dim"], Bool[Tensor, "1 depth tokens"]]:
    """Subsample MSA rows and move the selected rows to the front.

    Rows are chosen via ``_subsample_msa_rows``. The feature tensor keeps its
    original depth — selected rows are moved to the top, unselected rows
    follow — while the mask is restricted to the selected rows and then
    zero-padded back to the original depth, so unselected rows contribute
    nothing downstream.

    Args:
        feats: MSA feature tensor, one batch, per-row token features.
        mask: Boolean mask of non-padding positions per MSA row.
        select_n_rows: Maximum number of rows to keep active; ``<= 0``
            disables subsampling.
        generator: Optional RNG for reproducible row selection.

    Returns:
        ``(feats, mask)`` with the same shapes as the inputs. The inputs are
        returned unchanged when no subsampling occurs (``select_n_rows <= 0``
        or there are not enough non-empty rows).
    """
    selection_mask = _subsample_msa_rows(
        mask=mask,
        select_n_rows=select_n_rows,
        generator=generator,
    )
    if selection_mask is None:  # No subsampling
        return feats, mask

    # Select the rows; where returns in order from top to bottom, preserving order
    (selection_idx,) = torch.where(selection_mask)
    # Lazy %-style args avoid building the message when INFO is disabled, and
    # slicing the tensor before tolist() avoids materializing the full list.
    logging.info("Subsampling %s...", selection_idx[:5].tolist())
    (unselected_idx,) = torch.where(~selection_mask)
    combo_idx = torch.cat([selection_idx, unselected_idx])
    # Features are reordered, while mask is selected + padded
    feats_sampled = torch.index_select(feats, dim=1, index=combo_idx)
    mask_sampled = torch.index_select(mask, dim=1, index=selection_idx)
    # Every sampled row should have nonzero coverage
    assert mask_sampled.any(dim=-1).all()

    # Pad with zeros back to the original depth. n_pad is strictly positive
    # because _subsample_msa_rows returns None unless depth > select_n_rows.
    _, orig_depth, _ = mask.shape
    _, new_depth, _ = mask_sampled.shape
    assert (n_pad := orig_depth - new_depth) > 0
    # F.pad's spec runs last dim first, e.g., for the last two dimensions:
    # (left, right, top, bottom)
    # [0, 0, 0, n_pad] ignores the token dim and pads out the depth dim
    return (
        feats_sampled,
        F.pad(mask_sampled, pad=[0, 0, 0, n_pad], value=False),
    )

0 comments on commit 172b666

Please sign in to comment.