Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Support for subsampling different MSAs at each recycle and running multiple trunks #302

Merged
merged 5 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 46 additions & 11 deletions chai_lab/chai1.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
# See the LICENSE file for details.


import itertools
import logging
import math
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Sequence

import numpy as np
import torch
Expand All @@ -32,6 +35,7 @@
from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas
from chai_lab.data.dataset.msas.load import get_msa_contexts
from chai_lab.data.dataset.msas.msa_context import MSAContext
from chai_lab.data.dataset.msas.utils import subsample_msa_rows
from chai_lab.data.dataset.structure.all_atom_structure_context import (
AllAtomStructureContext,
)
Expand Down Expand Up @@ -295,6 +299,23 @@ def sorted(self) -> "StructureCandidates":
plddt=self.plddt[idx],
)

@classmethod
def concat(
    cls, candidates: Sequence["StructureCandidates"]
) -> "StructureCandidates":
    """Merge several candidate batches into one StructureCandidates.

    List-valued fields are concatenated in input order and tensor-valued
    fields are concatenated along the leading (sample) dimension. The
    coverage plot path of the first candidate is kept — presumably the
    plot is identical across trunk samples since they share one MSA
    (TODO confirm with callers).
    """
    merged_cif_paths = []
    merged_ranking_data = []
    for cand in candidates:
        merged_cif_paths.extend(cand.cif_paths)
        merged_ranking_data.extend(cand.ranking_data)
    return cls(
        cif_paths=merged_cif_paths,
        ranking_data=merged_ranking_data,
        msa_coverage_plot_path=candidates[0].msa_coverage_plot_path,
        pae=torch.cat([cand.pae for cand in candidates]),
        pde=torch.cat([cand.pde for cand in candidates]),
        plddt=torch.cat([cand.plddt for cand in candidates]),
    )


def make_all_atom_feature_context(
fasta_file: Path,
Expand Down Expand Up @@ -441,13 +462,16 @@ def run_inference(
msa_directory: Path | None = None,
constraint_path: Path | None = None,
# expose some params for easy tweaking
recycle_msa_subsample: int = 0,
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
num_diffn_samples: int = 5,
num_trunk_samples: int = 1,
seed: int | None = None,
device: str | None = None,
low_memory: bool = True,
) -> StructureCandidates:
assert num_trunk_samples > 0 and num_diffn_samples > 0
if output_dir.exists():
assert not any(
output_dir.iterdir()
Expand All @@ -466,16 +490,26 @@ def run_inference(
esm_device=torch_device,
)

return run_folding_on_context(
feature_context,
output_dir=output_dir,
num_trunk_recycles=num_trunk_recycles,
num_diffn_timesteps=num_diffn_timesteps,
num_diffn_samples=num_diffn_samples,
seed=seed,
device=torch_device,
low_memory=low_memory,
)
all_candidates: list[StructureCandidates] = []
for trunk_idx in range(num_trunk_samples):
logging.info(f"Trunk sample {trunk_idx + 1}/{num_trunk_samples}")
cand = run_folding_on_context(
feature_context,
output_dir=(
output_dir / f"trunk_{trunk_idx}"
if num_trunk_samples > 1
else output_dir
),
num_trunk_recycles=num_trunk_recycles,
num_diffn_timesteps=num_diffn_timesteps,
num_diffn_samples=num_diffn_samples,
recycle_msa_subsample=recycle_msa_subsample,
seed=seed + trunk_idx if seed is not None else None,
device=torch_device,
low_memory=low_memory,
)
all_candidates.append(cand)
return StructureCandidates.concat(all_candidates)


def _bin_centers(min_bin: float, max_bin: float, no_bins: int) -> Tensor:
Expand All @@ -488,6 +522,7 @@ def run_folding_on_context(
*,
output_dir: Path,
# expose some params for easy tweaking
recycle_msa_subsample: int = 0,
num_trunk_recycles: int = 3,
num_diffn_timesteps: int = 200,
# all diffusion samples come from the same trunk
Expand Down Expand Up @@ -647,7 +682,7 @@ def run_folding_on_context(
token_single_trunk_repr=token_single_trunk_repr, # recycled
token_pair_trunk_repr=token_pair_trunk_repr, # recycled
msa_input_feats=msa_input_feats,
msa_mask=msa_mask,
msa_mask=subsample_msa_rows(msa_mask, select_n_rows=recycle_msa_subsample),
template_input_feats=template_input_feats,
template_input_masks=template_input_masks,
token_single_mask=token_single_mask,
Expand Down
37 changes: 37 additions & 0 deletions chai_lab/data/dataset/msas/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2024 Chai Discovery, Inc.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for details.

import torch
from einops import rearrange, repeat
from torch import Tensor

from chai_lab.utils.typing import Bool, typecheck


@typecheck
def subsample_msa_rows(
    mask: Bool[Tensor, "1 depth tokens"],
    select_n_rows: int = 4096,
    generator: torch.Generator | None = None,
) -> Bool[Tensor, "1 depth tokens"]:
    """Restrict an MSA row mask to a random subset of its non-empty rows.

    The mask is returned unchanged when subsampling is disabled
    (select_n_rows <= 0) or when the number of rows that contain at least
    one unmasked token is already <= select_n_rows.
    """
    # A row is "live" iff at least one of its token positions is unmasked.
    row_has_data = mask.any(dim=-1).squeeze(0)
    n_live_rows = int(row_has_data.sum().item())
    if select_n_rows <= 0 or n_live_rows <= select_n_rows:
        return mask

    # Randomly keep select_n_rows of the live rows; fully-masked rows are
    # never candidates for selection.
    (live_indices,) = torch.where(row_has_data)
    assert live_indices.numel() > select_n_rows
    shuffled = torch.randperm(
        live_indices.numel(), device=mask.device, generator=generator
    )
    kept_indices = live_indices[shuffled[:select_n_rows]]

    # Build a per-row keep mask, then broadcast it over the token axis.
    keep_rows = torch.zeros_like(row_has_data)
    keep_rows[kept_indices] = True

    return mask & keep_rows.view(1, -1, 1)