From efb7b9b21388ea743bfbedc1448bfabbf4d3a206 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Wed, 12 Feb 2025 16:38:18 +0000 Subject: [PATCH 1/5] Subsample via mask --- chai_lab/chai1.py | 6 ++++- chai_lab/data/dataset/msas/utils.py | 36 +++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 chai_lab/data/dataset/msas/utils.py diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index e49a145..a96e195 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -32,6 +32,7 @@ from chai_lab.data.dataset.msas.colabfold import generate_colabfold_msas from chai_lab.data.dataset.msas.load import get_msa_contexts from chai_lab.data.dataset.msas.msa_context import MSAContext +from chai_lab.data.dataset.msas.utils import subsample_msa_rows from chai_lab.data.dataset.structure.all_atom_structure_context import ( AllAtomStructureContext, ) @@ -441,6 +442,7 @@ def run_inference( msa_directory: Path | None = None, constraint_path: Path | None = None, # expose some params for easy tweaking + recycle_msa_subsample: int = 0, num_trunk_recycles: int = 3, num_diffn_timesteps: int = 200, num_diffn_samples: int = 5, @@ -472,6 +474,7 @@ def run_inference( num_trunk_recycles=num_trunk_recycles, num_diffn_timesteps=num_diffn_timesteps, num_diffn_samples=num_diffn_samples, + recycle_msa_subsample=recycle_msa_subsample, seed=seed, device=torch_device, low_memory=low_memory, @@ -488,6 +491,7 @@ def run_folding_on_context( *, output_dir: Path, # expose some params for easy tweaking + recycle_msa_subsample: int = 0, num_trunk_recycles: int = 3, num_diffn_timesteps: int = 200, # all diffusion samples come from the same trunk @@ -647,7 +651,7 @@ def run_folding_on_context( token_single_trunk_repr=token_single_trunk_repr, # recycled token_pair_trunk_repr=token_pair_trunk_repr, # recycled msa_input_feats=msa_input_feats, - msa_mask=msa_mask, + msa_mask=subsample_msa_rows(msa_mask, select_n_rows=recycle_msa_subsample), 
template_input_feats=template_input_feats, template_input_masks=template_input_masks, token_single_mask=token_single_mask, diff --git a/chai_lab/data/dataset/msas/utils.py b/chai_lab/data/dataset/msas/utils.py new file mode 100644 index 0000000..6f88675 --- /dev/null +++ b/chai_lab/data/dataset/msas/utils.py @@ -0,0 +1,36 @@ +# Copyright (c) 2024 Chai Discovery, Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for details. + +import torch +from einops import rearrange, repeat +from torch import Tensor + +from chai_lab.utils.typing import Bool + + +def subsample_msa_rows( + mask: Bool[Tensor, "1 depth tokens"], + select_n_rows: int = 4096, + generator: torch.Generator | None = None, +) -> Bool[Tensor, "1 depth tokens"]: + """Adjust masking to look at a random subset of msas. + + Returns input mask as-is if select_n_rows <= 0 or depth <= select_n_rows.""" + nonnull_rows_mask = rearrange(mask.any(dim=-1), "1 d -> d") + input_depth = nonnull_rows_mask.sum().item() + if select_n_rows <= 0 or input_depth <= select_n_rows: + return mask + + # Select from rows of the MSA that are not fully masked out + (nonnull_row_indices,) = torch.where(nonnull_rows_mask) + assert (n := nonnull_row_indices.numel()) > select_n_rows + permuted = torch.randperm(n, device=mask.device, generator=generator) + selected_row_indices = nonnull_row_indices[permuted[:select_n_rows]] + + # Create a mask for selected row indices + selection_mask = torch.zeros_like(nonnull_rows_mask) + selection_mask[selected_row_indices] = True + selection_mask = repeat(selection_mask, "d -> 1 d 1") + + return mask & selection_mask From afca4a77707ce64defd4fcdecad5033faa3e7de1 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Wed, 12 Feb 2025 17:47:44 +0000 Subject: [PATCH 2/5] Add typechecking --- chai_lab/data/dataset/msas/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chai_lab/data/dataset/msas/utils.py b/chai_lab/data/dataset/msas/utils.py index 
6f88675..fe05468 100644 --- a/chai_lab/data/dataset/msas/utils.py +++ b/chai_lab/data/dataset/msas/utils.py @@ -6,9 +6,10 @@ from einops import rearrange, repeat from torch import Tensor -from chai_lab.utils.typing import Bool +from chai_lab.utils.typing import Bool, typecheck +@typecheck def subsample_msa_rows( mask: Bool[Tensor, "1 depth tokens"], select_n_rows: int = 4096, From acc98910f83a545f6c06f44384ff451489972cd9 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Wed, 12 Feb 2025 18:39:46 +0000 Subject: [PATCH 3/5] Native support for multiple trunk samples --- chai_lab/chai1.py | 48 ++++++++++++++++++++++++++++++++++++----------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index a96e195..1014db4 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -3,10 +3,13 @@ # See the LICENSE file for details. +import itertools +import logging import math from collections import Counter from dataclasses import dataclass from pathlib import Path +from typing import Sequence import numpy as np import torch @@ -296,6 +299,23 @@ def sorted(self) -> "StructureCandidates": plddt=self.plddt[idx], ) + @classmethod + def concat( + cls, candidates: Sequence["StructureCandidates"] + ) -> "StructureCandidates": + return cls( + cif_paths=list( + itertools.chain.from_iterable([c.cif_paths for c in candidates]) + ), + ranking_data=list( + itertools.chain.from_iterable([c.ranking_data for c in candidates]) + ), + msa_coverage_plot_path=candidates[0].msa_coverage_plot_path, + pae=torch.cat([c.pae for c in candidates]), + pde=torch.cat([c.pde for c in candidates]), + plddt=torch.cat([c.plddt for c in candidates]), + ) + def make_all_atom_feature_context( fasta_file: Path, @@ -446,6 +466,7 @@ def run_inference( num_trunk_recycles: int = 3, num_diffn_timesteps: int = 200, num_diffn_samples: int = 5, + num_trunk_samples: int = 5, seed: int | None = None, device: str | None = None, low_memory: bool = True, @@ -468,17 +489,22 @@ def 
run_inference( esm_device=torch_device, ) - return run_folding_on_context( - feature_context, - output_dir=output_dir, - num_trunk_recycles=num_trunk_recycles, - num_diffn_timesteps=num_diffn_timesteps, - num_diffn_samples=num_diffn_samples, - recycle_msa_subsample=recycle_msa_subsample, - seed=seed, - device=torch_device, - low_memory=low_memory, - ) + all_candidates: list[StructureCandidates] = [] + for trunk_idx in range(num_trunk_samples): + logging.info(f"Trunk sample {trunk_idx + 1}/{num_trunk_samples}") + cand = run_folding_on_context( + feature_context, + output_dir=output_dir / f"trunk_{trunk_idx}", + num_trunk_recycles=num_trunk_recycles, + num_diffn_timesteps=num_diffn_timesteps, + num_diffn_samples=num_diffn_samples, + recycle_msa_subsample=recycle_msa_subsample, + seed=seed + trunk_idx if seed is not None else None, + device=torch_device, + low_memory=low_memory, + ) + all_candidates.append(cand) + return StructureCandidates.concat(all_candidates) def _bin_centers(min_bin: float, max_bin: float, no_bins: int) -> Tensor: From e2ac27e74ab8696b0fa96d8ddd56905d56ec5e0f Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Wed, 12 Feb 2025 20:51:14 +0000 Subject: [PATCH 4/5] Default to one trunk --- chai_lab/chai1.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index 1014db4..3cafa40 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -466,7 +466,7 @@ def run_inference( num_trunk_recycles: int = 3, num_diffn_timesteps: int = 200, num_diffn_samples: int = 5, - num_trunk_samples: int = 5, + num_trunk_samples: int = 1, seed: int | None = None, device: str | None = None, low_memory: bool = True, @@ -494,7 +494,11 @@ def run_inference( logging.info(f"Trunk sample {trunk_idx + 1}/{num_trunk_samples}") cand = run_folding_on_context( feature_context, - output_dir=output_dir / f"trunk_{trunk_idx}", + output_dir=( + output_dir / f"trunk_{trunk_idx}" + if num_trunk_samples > 1 + else output_dir + ), 
 num_trunk_recycles=num_trunk_recycles, num_diffn_timesteps=num_diffn_timesteps, num_diffn_samples=num_diffn_samples, From 815b8217c85b5622c5d42d7d5d3fd5a8268b174e Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Wed, 12 Feb 2025 21:02:01 +0000 Subject: [PATCH 5/5] Assert positive trunk and diffusion sample counts --- chai_lab/chai1.py | 1 + 1 file changed, 1 insertion(+) diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index 3cafa40..6c251e3 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -471,6 +471,7 @@ def run_inference( device: str | None = None, low_memory: bool = True, ) -> StructureCandidates: + assert num_trunk_samples > 0 and num_diffn_samples > 0 if output_dir.exists(): assert not any( output_dir.iterdir()