Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Improvements to MSA subsampling #305

Merged
5 commits merged on
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,9 +402,7 @@ def make_all_atom_feature_context(

# Load templates
if templates_path is None:
assert (
not use_templates_server
), "Templates path should never be none when querying server for templates"
assert not use_templates_server, "Server should have written a path"
template_context = TemplateContext.empty(
n_tokens=n_actual_tokens,
n_templates=MAX_NUM_TEMPLATES,
Expand Down
2 changes: 1 addition & 1 deletion chai_lab/data/dataset/msas/msa_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __getitem__(self, subscript: tuple) -> "MSAContext":
mask=self.mask[subscript],
)

def take_rows_with_padding(self, row_indices_with_nones: list):
def take_rows_with_padding(self, row_indices_with_nones: list[int | None]):
"""
allows specifying index=None, which will be filled with empty sequence,
helpful to align multiple sequences
Expand Down
25 changes: 16 additions & 9 deletions chai_lab/data/dataset/msas/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# See the LICENSE file for details.

import torch
from einops import rearrange, repeat
from einops import rearrange, reduce, repeat
from torch import Tensor

from chai_lab.utils.typing import Bool, typecheck
Expand All @@ -18,20 +18,27 @@ def subsample_msa_rows(
"""Adjust masking to look at a random subset of msas.

Returns input mask as-is if select_n_rows <= 0 or depth < select_n_rows."""
nonnull_rows_mask = rearrange(mask.any(dim=-1), "1 d -> d")
# Count the number of non-padding residues in each row of the MSA
msa_sizes = rearrange(
reduce(mask, "b depth tok -> b depth", reduction="sum"), "1 depth -> depth"
)
nonnull_rows_mask = msa_sizes > 0
input_depth = nonnull_rows_mask.sum().item()
if select_n_rows <= 0 or input_depth <= select_n_rows:
return mask

# Select from rows of the MSA that are not fully masked out
(nonnull_row_indices,) = torch.where(nonnull_rows_mask)
assert (n := nonnull_row_indices.numel()) > select_n_rows
permuted = torch.randperm(n, device=mask.device, generator=generator)
selected_row_indices = nonnull_row_indices[permuted[:select_n_rows]]
# Bias towards bigger hit MSAs; 0 size is automatically nulled out
mask_ranking = msa_sizes * torch.rand(
size=msa_sizes.shape,
dtype=torch.float16,
device=msa_sizes.device,
generator=generator,
)
# Ascending sort -> choose the last (highest scoring) rows
selected_row_indices = mask_ranking.argsort()[-select_n_rows:]

# Create a mask for selected row indices
selection_mask = torch.zeros_like(nonnull_rows_mask)
selection_mask[selected_row_indices] = True
selection_mask = repeat(selection_mask, "d -> 1 d 1")

return mask & selection_mask
return mask & repeat(selection_mask, "d -> 1 d 1")