diff --git a/README.md b/README.md
index 11fcbfb..d01ac21 100644
--- a/README.md
+++ b/README.md
@@ -86,7 +86,7 @@ For user convenience, we also support automatic MSA generation via the ColabFold
-For more advanced use cases, we also expose the `chai_lab.chai1.run_folding_on_context`, which allows users to construct an `AllAtomFeatureContext` manually. This allows users to specify their own templates, MSAs, embeddings, and constraints. We currently provide an example of how to construct an embeddings context as well as an MSA context, and will be releasing helper methods to build template contexts soon. +For more advanced use cases, we also expose the `chai_lab.chai1.run_folding_on_context`, which allows users to construct an `AllAtomFeatureContext` manually. This allows users to specify their own templates, MSAs, embeddings, and constraints, including support for specifying covalent bonds (for example, for specifying branched ligands). We currently provide examples of how to construct an embeddings context, an MSA context, restraint contexts, and covalent bonds. We will be releasing helper methods to build template contexts soon.
diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index 7f8f2f9..3388c9d 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -35,6 +35,9 @@ from chai_lab.data.dataset.structure.all_atom_structure_context import ( AllAtomStructureContext, ) +from chai_lab.data.dataset.structure.bond_utils import ( + get_atom_covalent_bond_pairs_from_constraints, +) from chai_lab.data.dataset.templates.context import TemplateContext from chai_lab.data.features.feature_factory import FeatureFactory from chai_lab.data.features.feature_type import FeatureType @@ -76,6 +79,7 @@ TemplateResTypeGenerator, TemplateUnitVectorGenerator, ) +from chai_lab.data.features.generators.token_bond import TokenBondRestraint from chai_lab.data.features.generators.token_dist_restraint import ( TokenDistanceRestraint, ) @@ -370,11 +374,32 @@ def run_inference( # Constraints if constraint_path is not None: + # Handles contact and pocket restraints + pairs = parse_pairwise_table(constraint_path) restraint_context = load_manual_restraints_for_chai1( chains, crop_idces=None, - provided_constraints=parse_pairwise_table(constraint_path), + provided_constraints=pairs, + ) + # Handle covalent bond restraints + cov_a, cov_b = get_atom_covalent_bond_pairs_from_constraints( + provided_constraints=pairs, + token_residue_index=merged_context.token_residue_index, + token_residue_name=merged_context.token_residue_name, + token_subchain_id=merged_context.subchain_id, + token_asym_id=merged_context.token_asym_id, + atom_token_index=merged_context.atom_token_index, + atom_ref_name=merged_context.atom_ref_name, ) + if cov_a.numel() > 0 and cov_b.numel() > 0: + orig_a, orig_b = merged_context.atom_covalent_bond_indices + if orig_a.numel() == orig_b.numel() == 0: + merged_context.atom_covalent_bond_indices = (orig_a, orig_b) + else: + merged_context.atom_covalent_bond_indices = ( + torch.concatenate([orig_a, cov_a]), + torch.concatenate([orig_b, cov_b]), + ) else: restraint_context = RestraintContext.empty() @@ -437,6 +462,7 @@ def run_folding_on_context( raise_if_too_many_templates(feature_context.template_context.num_templates) raise_if_msa_too_deep(feature_context.msa_context.depth) # NOTE profile MSA used only for statistics; no depth check + feature_context.structure_context.report_bonds() ## ## Prepare batch @@ -480,6 +506,7 @@ def run_folding_on_context( assert model_size in AVAILABLE_MODEL_SIZES feature_embedding = load_exported("feature_embedding.pt", device) + bond_loss_input_proj = load_exported("bond_loss_input_proj.pt", device) token_input_embedder = load_exported("token_embedder.pt", device) trunk = load_exported("trunk.pt", device) diffusion_module = load_exported("diffusion_module.pt", device) @@ -503,6 +530,19 @@ def run_folding_on_context( template_input_feats = embedded_features["TEMPLATES"] msa_input_feats = embedded_features["MSA"] + ## + ## Bond feature generator + ## Separate from other feature embeddings due to export limitations + ## + + bond_ft_gen = TokenBondRestraint() + bond_ft = bond_ft_gen.generate(batch=batch).data + trunk_bond_feat, structure_bond_feat = bond_loss_input_proj.forward( + crop_size=model_size, input=bond_ft + ).chunk(2, dim=-1) + token_pair_input_feats += trunk_bond_feat + token_pair_structure_input_feats += structure_bond_feat + ## ## Run the inputs through the token input embedder ## diff --git a/chai_lab/data/dataset/constraints/restraint_context.py b/chai_lab/data/dataset/constraints/restraint_context.py index a24f7a9..28348a2 100644 --- a/chai_lab/data/dataset/constraints/restraint_context.py +++ b/chai_lab/data/dataset/constraints/restraint_context.py @@ -98,7 +98,7 @@ def load_manual_restraints_for_chai1( contact_constraints: list[ContactRestraint] = [] pocket_constraints: list[PocketRestraint] = [] - logger.info(f"Loading {len(provided_constraints)} constraints...") + logger.info(f"Loading {len(provided_constraints)} restraints...") for constraint in provided_constraints: match ctype := constraint.connection_type: case PairwiseInteractionType.COVALENT: diff --git a/chai_lab/data/dataset/inference_dataset.py b/chai_lab/data/dataset/inference_dataset.py index 5caaa55..400aa8e 100644 --- a/chai_lab/data/dataset/inference_dataset.py +++ b/chai_lab/data/dataset/inference_dataset.py @@ -19,6 +19,7 @@ ) from chai_lab.data.dataset.structure.chain import Chain from chai_lab.data.parsing.fasta import get_residue_name, read_fasta +from chai_lab.data.parsing.glycans import glycan_string_residues from chai_lab.data.parsing.input_validation import ( constituents_of_modified_fasta, identify_potential_entity_types, @@ -118,6 +119,8 @@ def raw_inputs_to_entitites_data( for r in parsed_sequence ] residues = get_polymer_residues(expanded_sequence, entity_type) + case EntityType.MANUAL_GLYCAN: + residues = glycan_string_residues(glycan_string=input.sequence) case _: raise NotImplementedError assert residues is not None @@ -145,6 +148,7 @@ def raw_inputs_to_entitites_data( method="none", entity_type=entity_type, subchain_id=_synth_subchain_id(i), + original_record=input.sequence, ) ) @@ -232,6 +236,8 @@ def read_inputs(fasta_file: str | Path, length_limit: int | None = None) -> list entity_type = EntityType.RNA case "dna": entity_type = EntityType.DNA + case "glycan": + entity_type = EntityType.MANUAL_GLYCAN case _: raise ValueError(f"{entity_str} is not a valid entity type") diff --git a/chai_lab/data/dataset/structure/all_atom_residue_tokenizer.py b/chai_lab/data/dataset/structure/all_atom_residue_tokenizer.py index 1aceed9..5078121 100644 --- a/chai_lab/data/dataset/structure/all_atom_residue_tokenizer.py +++ b/chai_lab/data/dataset/structure/all_atom_residue_tokenizer.py @@ -14,6 +14,9 @@ from chai_lab.data.dataset.structure.all_atom_structure_context import ( AllAtomStructureContext, ) +from chai_lab.data.dataset.structure.bond_utils import ( + get_atom_covalent_bond_pairs_from_glycan_string, +) from chai_lab.data.dataset.structure.utils import ( backbone_atoms_all_present, backbone_atoms_indices, @@ -510,6 +513,15 @@ def _tokenize_entity( dtype=torch.bool, ), symmetries=tokens.symmetries, + atom_covalent_bond_indices=get_atom_covalent_bond_pairs_from_glycan_string( + glycan_string=( + entity_data.original_record + if entity_data.entity_type == EntityType.MANUAL_GLYCAN + else "" + ), + token_residue_index=tokens.residue_index, + atom_ref_name=tokens.atom_names, + ), ) def _get_ref_conformer_data(self, residue: Residue) -> ConformerData: diff --git a/chai_lab/data/dataset/structure/all_atom_structure_context.py b/chai_lab/data/dataset/structure/all_atom_structure_context.py index 1879749..ab02276 100644 --- a/chai_lab/data/dataset/structure/all_atom_structure_context.py +++ b/chai_lab/data/dataset/structure/all_atom_structure_context.py @@ -65,6 +65,8 @@ class AllAtomStructureContext: is_distillation: Bool[Tensor, "1"] # symmetric atom swap indices symmetries: Int[Tensor, "n_atoms n_symmetries"] + # atom-wise bond feature; corresponding lists of atoms that are covalently bound + atom_covalent_bond_indices: tuple[Int[Tensor, "n_bonds"], Int[Tensor, "n_bonds"]] def __post_init__(self): # Resolved residues filter should eliminate PDBs with missing residues, but that @@ -82,10 +84,29 @@ def __post_init__(self): pdb_id = tensorcode_to_string(self.pdb_id[0]) logger.error(f"Incompatible masks for {pdb_id}") + # Check that bonds are specified in atom space + assert torch.all(self.atom_covalent_bond_indices[0] < self.num_atoms) + assert torch.all(self.atom_covalent_bond_indices[1] < self.num_atoms) + @cached_property def residue_names(self) -> list[str]: return batch_tensorcode_to_string(self.token_residue_name) + def report_bonds(self) -> None: + """Log information about covalent bonds.""" + for i, (atom_a, atom_b) in enumerate(zip(*self.atom_covalent_bond_indices)): + tok_a = self.atom_token_index[atom_a] + tok_b = self.atom_token_index[atom_b] + asym_a = self.token_asym_id[tok_a] + asym_b = self.token_asym_id[tok_b] + res_idx_a = self.token_residue_index[tok_a] + res_idx_b = self.token_residue_index[tok_b] + resname_a = tensorcode_to_string(self.token_residue_name[tok_a]) + resname_b = tensorcode_to_string(self.token_residue_name[tok_b]) + logging.info( + f"Bond {i} (asym res_idx resname): {asym_a} {res_idx_a} {resname_a} <> {asym_b} {res_idx_b} {resname_b}" + ) + def pad( self, n_tokens: int, @@ -142,6 +163,7 @@ def pad( resolution=self.resolution, is_distillation=self.is_distillation, symmetries=pad_atoms_func(self.symmetries, pad_value=-1), + atom_covalent_bond_indices=self.atom_covalent_bond_indices, ) @typecheck @@ -177,6 +199,30 @@ def merge( n_tokens = sum(x.num_tokens for x in contexts) token_index = torch.arange(n_tokens, dtype=torch.int) + # Merge and offset bond indices, which are indexed by *token* + atom_covalent_bond_indices_manual_a = [] + atom_covalent_bond_indices_manual_b = [] + for ctx, count in zip(contexts, atom_offsets): + if ctx.atom_covalent_bond_indices is None: + continue + a, b = ctx.atom_covalent_bond_indices + atom_covalent_bond_indices_manual_a.append(a + count) + atom_covalent_bond_indices_manual_b.append(b + count) + assert len(atom_covalent_bond_indices_manual_a) == len( + atom_covalent_bond_indices_manual_b + ) + atom_covalent_bond_indices = ( + ( + torch.concatenate(atom_covalent_bond_indices_manual_a), + torch.concatenate(atom_covalent_bond_indices_manual_b), + ) + if atom_covalent_bond_indices_manual_a + else ( + torch.zeros(0, dtype=torch.long), + torch.zeros(0, dtype=torch.long), + ) + ) + # re-index the reference space from 0..n_tokens-1. zero_indexed_ref_uids = [ torch.unique_consecutive(x.atom_ref_space_uid, return_inverse=True)[1] @@ -255,6 +301,7 @@ def merge( torch.stack([x.is_distillation for x in contexts]), 0 ).values, symmetries=symmetries, + atom_covalent_bond_indices=atom_covalent_bond_indices, ) def to(self, device: torch.device | str) -> "AllAtomStructureContext": diff --git a/chai_lab/data/dataset/structure/bond_utils.py b/chai_lab/data/dataset/structure/bond_utils.py new file mode 100644 index 0000000..3830643 --- /dev/null +++ b/chai_lab/data/dataset/structure/bond_utils.py @@ -0,0 +1,168 @@ +# Copyright (c) 2024 Chai Discovery, Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for details. + +import torch +from einops import rearrange +from torch import Tensor + +from chai_lab.data import residue_constants as rc +from chai_lab.data.parsing.glycans import _glycan_string_to_sugars_and_bonds +from chai_lab.data.parsing.restraints import ( + PairwiseInteraction, + PairwiseInteractionType, +) +from chai_lab.model.utils import get_asym_id_from_subchain_id +from chai_lab.utils.tensor_utils import string_to_tensorcode +from chai_lab.utils.typing import Int, UInt8, typecheck + + +@typecheck +def get_atom_covalent_bond_pairs_from_constraints( + provided_constraints: list[PairwiseInteraction], + token_residue_index: Int[Tensor, "n_tokens"], + token_residue_name: UInt8[Tensor, "n_tokens 8"], + token_subchain_id: UInt8[Tensor, "n_tokens 4"], + token_asym_id: Int[Tensor, "n_tokens"], + atom_token_index: Int[Tensor, "n_atoms"], + atom_ref_name: list[str], +) -> tuple[Int[Tensor, "n_bonds"], Int[Tensor, "n_bonds"]]: + """Determine bond pairs, which are returned as atom indices that are bonded.""" + ret_a: list[int] = [] + ret_b: list[int] = [] + for constraint in provided_constraints: + match ctype := constraint.connection_type: + case PairwiseInteractionType.COVALENT: + assert ( + constraint.atom_nameA and constraint.atom_nameB + ), "Atoms must be provided for covalent bonds" + # Figure out the asym id that we care about + left_asym_id = get_asym_id_from_subchain_id( + subchain_id=constraint.chainA, + source_pdb_chain_id=token_subchain_id, + token_asym_id=token_asym_id, + ) + left_token_asym_mask = token_asym_id == left_asym_id + right_asym_id = get_asym_id_from_subchain_id( + subchain_id=constraint.chainB, + source_pdb_chain_id=token_subchain_id, + token_asym_id=token_asym_id, + ) + right_token_asym_mask = token_asym_id == right_asym_id + assert torch.any(left_token_asym_mask) and torch.any( + right_token_asym_mask + ) + + # Get the token index that we want + left_token_index_mask = ( + token_residue_index == constraint.res_idxA_pos - 1 + ) + right_token_index_mask = ( + token_residue_index == constraint.res_idxB_pos - 1 + ) + assert torch.any(left_token_index_mask) and torch.any( + right_token_index_mask + ) + + # Combine these to get the specific residue specified + left_residue_mask = left_token_asym_mask & left_token_index_mask + if constraint.res_idxA_name: + three_letter = string_to_tensorcode( + rc.restype_1to3.get(constraint.res_idxA_name, "UNK"), + pad_to_length=token_residue_name.shape[-1], + ) + resname_matches = ( + token_residue_name == rearrange(three_letter, "d -> 1 d") + ).all(dim=-1) + assert resname_matches.shape == left_residue_mask.shape + left_residue_mask &= resname_matches + right_residue_mask = right_token_asym_mask & right_token_index_mask + if constraint.res_idxB_name: + three_letter = string_to_tensorcode( + rc.restype_1to3.get(constraint.res_idxB_name, "UNK"), + pad_to_length=token_residue_name.shape[-1], + ) + resname_matches = ( + token_residue_name == rearrange(three_letter, "d -> 1 d") + ).all(dim=-1) + assert resname_matches.shape == right_residue_mask.shape + right_residue_mask &= resname_matches + # NOTE there are multiple residues in these residue masks due to + # per-atom tokenization of glycans + # These indices do not reset for new chains (matching atom_token_index) + left_residue_idx = torch.where(left_residue_mask)[0] + right_residue_idx = torch.where(right_residue_mask)[0] + assert left_residue_idx.numel() > 0 and right_residue_idx.numel() > 0 + + # Find the atoms belonging to this residue + left_atoms_mask = torch.isin( + atom_token_index, test_elements=left_residue_idx + ) + right_atoms_mask = torch.isin( + atom_token_index, test_elements=right_residue_idx + ) + assert torch.any(left_atoms_mask) and torch.any(right_atoms_mask) + + # Find atoms matching on atom name + left_name_mask = torch.tensor( + [n == constraint.atom_nameA for n in atom_ref_name], + dtype=torch.bool, + ) + right_name_mask = torch.tensor( + [n == constraint.atom_nameB for n in atom_ref_name], + dtype=torch.bool, + ) + + left_atom_mask = left_atoms_mask & left_name_mask + right_atom_mask = right_atoms_mask & right_name_mask + assert ( + torch.sum(left_atom_mask) == torch.sum(right_atom_mask) == 1 + ), f"Expect single atoms, got {torch.sum(left_atom_mask)}, {torch.sum(right_atom_mask)}" + + (left_atom_idx,) = torch.where(left_atom_mask) + (right_atom_idx,) = torch.where(right_atom_mask) + ret_a.append(left_atom_idx.item()) # type: ignore + ret_b.append(right_atom_idx.item()) # type: ignore + + case PairwiseInteractionType.CONTACT | PairwiseInteractionType.POCKET: + # These are handled as constraints, not as bonds + pass + case _: + raise ValueError(f"Unrecognized pariwise interaction: {ctype}") + return torch.tensor(ret_a, dtype=torch.int), torch.tensor(ret_b, dtype=torch.int) + + +@typecheck +def get_atom_covalent_bond_pairs_from_glycan_string( + glycan_string: str, + token_residue_index: Int[Tensor, "n_tokens"], + atom_ref_name: list[str], +) -> tuple[Int[Tensor, "n_bonds"], Int[Tensor, "n_bonds"]]: + """Infer bond pairs between glycans sugar rings.""" + if not glycan_string: + return ( + torch.zeros(0, dtype=torch.long), + torch.zeros(0, dtype=torch.long), + ) + + assert token_residue_index.numel() == len(atom_ref_name) + _sugars, bonds = _glycan_string_to_sugars_and_bonds(glycan_string) + left_bonds, right_bonds = [], [] + for bond in bonds: + left_chain_mask = token_residue_index == bond.src_sugar_index + left_res_mask = [n == bond.src_atom_name for n in atom_ref_name] + right_chain_mask = token_residue_index == bond.dst_sugar_index + right_res_mask = [n == bond.dst_atom_name for n in atom_ref_name] + left_res = left_chain_mask & torch.tensor(left_res_mask) + right_res = right_chain_mask & torch.tensor(right_res_mask) + assert left_res.sum() == 1, f"Expected unique atom, got {left_res=}" + assert right_res.sum() == 1, f"Expected unique atom, got {right_res=}" + left_res_idx, *_ = torch.where(left_res) + right_res_idx, *_ = torch.where(right_res) + left_bonds.append(left_res_idx.item()) + right_bonds.append(right_res_idx.item()) + bonds_to_add = ( + torch.tensor(left_bonds, dtype=torch.int), + torch.tensor(right_bonds, dtype=torch.int), + ) + return bonds_to_add diff --git a/chai_lab/data/features/generators/token_bond.py b/chai_lab/data/features/generators/token_bond.py new file mode 100644 index 0000000..6d7299e --- /dev/null +++ b/chai_lab/data/features/generators/token_bond.py @@ -0,0 +1,63 @@ +# Copyright (c) 2024 Chai Discovery, Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for details. + +from typing import Any + +import torch +from torch import Tensor + +from chai_lab.data.features.feature_type import FeatureType +from chai_lab.data.features.generators.base import EncodingType, FeatureGenerator +from chai_lab.utils.tensor_utils import und_self +from chai_lab.utils.typing import Bool, Int, typecheck + + +class TokenBondRestraint(FeatureGenerator): + def __init__(self): + """Generates features for covalent bonds between atoms.""" + super().__init__( + ty=FeatureType.TOKEN_PAIR, + encoding_ty=EncodingType.IDENTITY, + can_mask=False, + num_classes=1, + mult=1, + ) + + def get_input_kwargs_from_batch(self, batch: dict[str, Any]) -> dict: + return dict( + token_exists_mask=batch["inputs"]["token_exists_mask"], + atom_token_index=batch["inputs"]["atom_token_index"].long(), + atom_covalent_bond_indices=batch["inputs"]["atom_covalent_bond_indices"], + ) + + @typecheck + def apply_mask(self, feature: Tensor, mask: Tensor, mask_ty: FeatureType) -> Tensor: + # override masking behavior - just return the unmasked feature + return feature + + @typecheck + def _generate( + self, + token_exists_mask: Bool[Tensor, "b n"], + atom_token_index: Int[Tensor, "b a"], + atom_covalent_bond_indices: list[ + tuple[Int[Tensor, "bonds"], Int[Tensor, "bonds"]] + ], + ) -> Tensor: + token_pair_mask = und_self(token_exists_mask, "b i, b j -> b i j") + bond_feature = torch.zeros_like(token_pair_mask.float()) + + for batch_idx, (left_indices, right_indices) in enumerate( + atom_covalent_bond_indices + ): + # convert from atom index to token index + left_token_indices = torch.gather( + atom_token_index[batch_idx], dim=0, index=left_indices + ) + right_token_indices = torch.gather( + atom_token_index[batch_idx], dim=0, index=right_indices + ) + bond_feature[batch_idx][left_token_indices, right_token_indices] = 1 + + return self.make_feature(bond_feature.unsqueeze(-1)) diff --git a/chai_lab/data/io/cif_utils.py b/chai_lab/data/io/cif_utils.py index 1a71320..db6c11e 100644 --- a/chai_lab/data/io/cif_utils.py +++ b/chai_lab/data/io/cif_utils.py @@ -8,7 +8,14 @@ import gemmi import modelcif import torch -from ihm import ChemComp, DNAChemComp, LPeptideChemComp, NonPolymerChemComp, RNAChemComp +from ihm import ( + ChemComp, + DNAChemComp, + LPeptideChemComp, + NonPolymerChemComp, + RNAChemComp, + SaccharideChemComp, +) from modelcif import Assembly, AsymUnit, Entity, dumper, model from torch import Tensor @@ -93,6 +100,8 @@ def _to_chem_component(res_name_3: str, entity_type: int): case EntityType.LIGAND.value: code = res_name_3 return NonPolymerChemComp(res_name_3) + case EntityType.MANUAL_GLYCAN.value: + return SaccharideChemComp(id=res_name_3, name=res_name_3) case EntityType.PROTEIN.value: code = restype_3to1.get(res_name_3, res_name_3) one_letter_code = gemmi.find_tabulated_residue(res_name_3).one_letter_code @@ -105,7 +114,7 @@ def _to_chem_component(res_name_3: str, entity_type: int): code = res_name_3 return RNAChemComp(res_name_3, code, code_canonical=code) case _: - raise NotImplementedError + raise NotImplementedError(f"Cannot handle entity type: {entity_type}") def sequence_to_chem_comps(sequence: list[str], entity_type: int) -> list[ChemComp]: diff --git a/chai_lab/data/parsing/glycans.py b/chai_lab/data/parsing/glycans.py new file mode 100644 index 0000000..9621a89 --- /dev/null +++ b/chai_lab/data/parsing/glycans.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024 Chai Discovery, Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for details. + +"""Parsing logic for glycans. + +- Each sugar in the glycan gets a distinct residue + - Each residue has an increasing label seq + +Example glycan strings: +MAN(6-1 FUC)(4-1 MAN) +""" + +import re +from functools import lru_cache + +from attr import dataclass + +from chai_lab.data import residue_constants as rc +from chai_lab.data.parsing.structure.residue import Residue + + +@dataclass(frozen=True) +class GlycosidicBond: + src_sugar_index: int # 0-indexed + dst_sugar_index: int # 0-indexed + src_atom: int # 1-indexed + dst_atom: int # 1-indexed + + def __post_init__(self): + assert self.src_sugar_index != self.dst_sugar_index + assert self.src_atom > 0 and self.dst_atom > 0 + + @property + def src_atom_name(self) -> str: + return f"C{self.src_atom}" + + @property + def dst_atom_name(self) -> str: + return f"C{self.dst_atom}" + + +@lru_cache(maxsize=32) +def _glycan_string_to_sugars_and_bonds( + glycan_string: str, +) -> tuple[list[str], list[GlycosidicBond]]: + """Parses the glycan string to its constituent sugars and bonds.""" + sugars: list[str] = [] # Tracks all sugars + parent_sugar_idx: list[int] = [] # Tracks the parent sugar for bond formation + bonds: list[GlycosidicBond] = [] + open_count, closed_count = 0, 0 + for i in range(len(glycan_string)): + char = glycan_string[i] + if char == " ": + continue + if char == "(": + open_count += 1 + continue + if char == ")": + closed_count += 1 + parent_sugar_idx.pop() # Remove + continue + chunk = glycan_string[i : i + 3] + if re.match(r"[A-Z]{3}", chunk): + sugars.append(chunk) + parent_sugar_idx.append(len(sugars) - 1) # latest sugar + elif re.match(r"[1-6]{1}-[1-6]{1}", chunk): + s, d = chunk.split("-") + assert parent_sugar_idx + bonds.append( + GlycosidicBond( + src_sugar_index=parent_sugar_idx[-1], + dst_sugar_index=len(sugars), # Anticipate next + src_atom=int(s), + dst_atom=int(d), + ) + ) + assert open_count == closed_count + return sugars, bonds + + +def glycan_string_residues(glycan_string: str) -> list[Residue]: + sugars, _bonds = _glycan_string_to_sugars_and_bonds(glycan_string) + return [ + Residue( + name=sugar, + label_seq=i + 1, + restype=rc.residue_types_with_nucleotides_order["X"], + residue_index=i, + is_missing=False, + b_factor_or_plddt=0.0, + conformer_data=None, + is_covalent_bonded=True, + ) + for i, sugar in enumerate(sugars) + ] diff --git a/chai_lab/data/parsing/input_validation.py b/chai_lab/data/parsing/input_validation.py index 091e295..c1de044 100644 --- a/chai_lab/data/parsing/input_validation.py +++ b/chai_lab/data/parsing/input_validation.py @@ -74,5 +74,5 @@ def identify_potential_entity_types(sequence: str) -> list[EntityType]: ascii_symbols = string.ascii_letters + string.digits + ".-+=#$%:/\\[]()<>@" if set.issubset(set(sequence.upper()), set(ascii_symbols)): - possible_entity_types.append(EntityType.LIGAND) + possible_entity_types.extend([EntityType.LIGAND, EntityType.MANUAL_GLYCAN]) return possible_entity_types diff --git a/chai_lab/data/parsing/restraints.py b/chai_lab/data/parsing/restraints.py index 2b1bbab..294c745 100644 --- a/chai_lab/data/parsing/restraints.py +++ b/chai_lab/data/parsing/restraints.py @@ -90,22 +90,25 @@ def __post_init__(self): @property def res_idxA_name(self) -> str: """Single-char name of residue A.""" - return self.res_idxA[0] + return self.res_idxA[0] if self.res_idxA else "" @property def res_idxA_pos(self) -> int: - """1-indexed position of residue A.""" - return int(self.res_idxA[1:]) + """1-indexed position of residue A; defaults to 1 if not given.""" + # NOTE 1 default is because 1 is the minimum index under 1-indexing + s = self.res_idxA[1:] + return int(s) if s else 1 @property def res_idxB_name(self) -> str: """Single-char name of residue B.""" - return self.res_idxB[0] + return self.res_idxB[0] if self.res_idxB else "" @property def res_idxB_pos(self) -> int: - """1-indexed position of residue B.""" - return int(self.res_idxB[1:]) + """1-indexed position of residue B; defaults to 1 if not given.""" + s = self.res_idxB[1:] + return int(s) if s else 1 def to_table_entry(self) -> dict[str, str | float]: """Format as table entry, sans leading restraint_id column.""" diff --git a/chai_lab/data/parsing/structure/all_atom_entity_data.py b/chai_lab/data/parsing/structure/all_atom_entity_data.py index 270bc0c..3758d49 100644 --- a/chai_lab/data/parsing/structure/all_atom_entity_data.py +++ b/chai_lab/data/parsing/structure/all_atom_entity_data.py @@ -34,6 +34,7 @@ class AllAtomEntityData: entity_type: EntityType subchain_id: str is_d_polypeptide: bool = False # NOTE (mostly) exists for eval set construction + original_record: str = "" # NOTE for glycan parsing def __post_init__(self): assert ( diff --git a/chai_lab/data/parsing/structure/entity_type.py b/chai_lab/data/parsing/structure/entity_type.py index f8c1950..f681125 100644 --- a/chai_lab/data/parsing/structure/entity_type.py +++ b/chai_lab/data/parsing/structure/entity_type.py @@ -16,3 +16,4 @@ class EntityType(Enum): POLYMER_HYBRID = 4 WATER = 5 UNKNOWN = 6 + MANUAL_GLYCAN = 7 # NOTE glycan parsing diff --git a/chai_lab/data/parsing/structure/residue.py b/chai_lab/data/parsing/structure/residue.py index ad03d9d..a95a2d0 100644 --- a/chai_lab/data/parsing/structure/residue.py +++ b/chai_lab/data/parsing/structure/residue.py @@ -77,6 +77,7 @@ class Residue: b_factor_or_plddt: float conformer_data: ConformerData | None smiles: str | None = None + is_covalent_bonded: bool = False def get_restype( diff --git a/examples/glycosylation/1ac5.fasta b/examples/glycosylation/1ac5.fasta new file mode 100644 index 0000000..3e952e0 --- /dev/null +++ b/examples/glycosylation/1ac5.fasta @@ -0,0 +1,6 @@ +>protein|1AC5 +LPSSEEYKVAYELLPGLSEVPDPSNIPQMHAGHIPLRSEDADEQDSSDLEYFFWKFTNNDSNGNVDRPLIIWLNGGPGCSSMDGALVESGPFRVNSDGKLYLNEGSWISKGDLLFIDQPTGTGFSVEQNKDEGKIDKNKFDEDLEDVTKHFMDFLENYFKIFPEDLTRKIILSGESYAGQYIPFFANAILNHNKFSKIDGDTYDLKALLIGNGWIDPNTQSLSYLPFAMEKKLIDESNPNFKHLTNAHENCQNLINSASTDEAAHFSYQECENILNLLLSYTRESSQKGTADCLNMYNFNLKDSYPSCGMNWPKDISFVSKFFSTPGVIDSLHLDSDKIDHWKECTNSVGTKLSNPISKPSIHLLPGLLESGIEIVLFNGDKDLICNNKGVLDTIDNLKWGGIKGFSDDAVSFDWIHKSKSTDDSEEFSGYVKYDRNLTFVSVYNASHMVPFDKSLVSRGIVDIYSNDVMIIDNNGKNVMITT +>glycan|two-sugar +NAG(1-4 NAG) +>glycan|one-sugar +NAG \ No newline at end of file diff --git a/examples/glycosylation/README.md b/examples/glycosylation/README.md new file mode 100644 index 0000000..f9ca907 --- /dev/null +++ b/examples/glycosylation/README.md @@ -0,0 +1,66 @@ +# Working with bond restraints + +Chai-1 supports specifying covalent bonds as input, which specify covalent linkages between atoms in the folded complex. This is useful for specifying covalent modifications such as glycosylation events, which we demonstrate below, but can be generally used to specify arbitrary "non-canonical" bonds in a structure. + +A few notes: +- Chai-1 was not trained on disulfide bonds, and we have not evaluated whether specifying such bond information yields expected behaviors. +- These bond restraints should not be used to specify modified amino acids that already have an associated CCD code; for these examples, include the modified residue's CCD code in parentheses directly in the sequence in place of its canonical residue, e.g., `RKDES(MSE)EES` to specify a selenomethionine at the 6th position. + +## Glycans + +We adopt an abbreviated syntax for specifying glycans, which is best explained with a series of examples. + +### Single-ring glycan + +Let's say we have a glycan that is a single sugar ring, a 2-acetamido-2-deoxy-beta-D-glucopyranose. The [CCD code](https://www.rcsb.org/ligand/NAG) for this sugar is `NAG`, so we simply specify this sugar with the following fasta entry: +``` +>protein|example-protein +...N... +>glycan|example-single-sugar +NAG +``` + +Now, a glycan is also covalently bound to a residue; to specify this, we include the following line in our restraints file (see our documentation on restraints as well): + +chainA|res_idxA|chainB|res_idxB|connection_type|confidence|min_distance_angstrom|max_distance_angstrom|comment|restraint_id +|---|---|---|---|---|---|---|---|---|---| +A|N436@N|B|@C4|covalent|1.0|0.0|0.0|protein-glycan|bond1 + +Breaking this down, this specifies that the within chain A (the first entry in the fasta), the "N" residue at the 436-th position (1-indexed) as indicated by the "N436" prefix is bound, via its nitrogen "N" atom as indicated by the "@N" suffix, to the C4 atom in the first glycan ("@C4"). Ring numbering follows standard glycan ring number schemas. For other ligands, you will need check how the specific version of `rdkit` that we use in `chai-lab` (run `uv pip list | grep rdkit` for version) assigns atom names and use the same atom names to specify your bonds. In addition, note that the min and max distance fields are ignored, as is the confidence field. + + +### Multi-ring glycan + +Working through a more complex example, let's say we have a two-ring ligand such as that shown in the PDB structure [1AC5](https://www.rcsb.org/structure/1ac5). We introduce syntax for specifying bonds **within glycans** within the fasta record as such: + +``` +>protein|example-protein +...N... +>glycan|example-dual-sugar +NAG(1-4 NAG) +``` + +This syntax specifies that the root of the glycan is the leading `NAG` ring. The parentheses indicate that we are attaching another molecule to the ring directly preceding the parentheses. The `1-4` syntax "draws" a bond between the C1 atom of the previous "root" `NAG` and the C4 atom of the subsequent `NAG` ring. To specify how this glycan ought to be connected to the protein, we again use the restraints file to specify a residue and atom to which the glycan is bound, and the carbon atom within the root glycan ring that is bound. + +chainA|res_idxA|chainB|res_idxB|connection_type|confidence|min_distance_angstrom|max_distance_angstrom|comment|restraint_id +|---|---|---|---|---|---|---|---|---|---| +A|N436@N|B|@C4|covalent|1.0|0.0|0.0|protein-glycan|bond1 + +You can chain this syntax to create longer ligands: +``` +>glycan|4-NAG-in-a-linear-chain +NAG(1-4 NAG(1-4 NAG(1-4 NAG))) +``` + +...and to create branched ligands +``` +>glycan|branched-glycan +NAG(1-4 NAG(1-4 NAG))(3-4 MAN) +``` +This branched example has a root `NAG` ring with a branch with two more `NAG` rings and a branch with a single `MAN` ring. For additional examples, please refer to the examples tested in the `tests/test_glycans.py` test file. + +### Example + +We have included an example of how glycans can be specified under `predict_glycosylated.py` in this directory, along with its corresponding `bonds.restraints` csv file. This example is based on the PDB structure [1AC5](https://www.rcsb.org/structure/1ac5). The predicted structrue (colored, glycans in purple and orange, protein in green) from this script should look like the following when aligned with the ground truth 1AC5 structure (gray): + + diff --git a/examples/glycosylation/bonds.restraints b/examples/glycosylation/bonds.restraints new file mode 100644 index 0000000..d48b6ac --- /dev/null +++ b/examples/glycosylation/bonds.restraints @@ -0,0 +1,3 @@ +chainA,res_idxA,chainB,res_idxB,connection_type,confidence,min_distance_angstrom,max_distance_angstrom,comment,restraint_id +A,N437@N,B,@C4,covalent,1.0,0.0,0.0,protein-glycan,bond1 +A,N445@N,C,@C4,covalent,1.0,0.0,0.0,protein-glycan,bond2 \ No newline at end of file diff --git a/examples/glycosylation/output.png b/examples/glycosylation/output.png new file mode 100644 index 0000000..aa076df Binary files /dev/null and b/examples/glycosylation/output.png differ diff --git a/examples/glycosylation/predict_glycosylated.py b/examples/glycosylation/predict_glycosylated.py new file mode 100644 index 0000000..4734b77 --- /dev/null +++ b/examples/glycosylation/predict_glycosylated.py @@ -0,0 +1,27 @@ +import logging +import shutil +from pathlib import Path + +from chai_lab.chai1 import run_inference + +logging.basicConfig(level=logging.INFO) + +# Inference expects an empty directory; enforce this +output_dir = Path("/workspaces/chai-lab/tmp/outputs") +if output_dir.exists(): + logging.warning(f"Removing old output directory: {output_dir}") + shutil.rmtree(output_dir) +output_dir.mkdir(exist_ok=True, parents=True) + +candidates = run_inference( + fasta_file=Path(__file__).with_name("1ac5.fasta"), + output_dir=output_dir, + constraint_path=Path(__file__).with_name("bonds.restraints"), + num_trunk_recycles=3, + num_diffn_timesteps=200, + seed=42, + device="cuda:0", + use_esm_embeddings=True, +) + +cif_paths = candidates.cif_paths diff --git a/examples/predict_structure.py b/examples/predict_structure.py index 12b4d72..a83d10b 100644 --- a/examples/predict_structure.py +++ b/examples/predict_structure.py @@ -1,3 +1,5 @@ +import logging +import shutil from pathlib import Path import numpy as np @@ -26,7 +28,12 @@ fasta_path = Path("/tmp/example.fasta") fasta_path.write_text(example_fasta) +# Inference expects an empty directory; enforce this output_dir = Path("/tmp/outputs") +if output_dir.exists(): + logging.warning(f"Removing old output directory: {output_dir}") + shutil.rmtree(output_dir) +output_dir.mkdir(exist_ok=True) candidates = run_inference( fasta_file=fasta_path, diff --git a/tests/test_glycans.py b/tests/test_glycans.py new file mode 100644 index 0000000..7c3d646 --- /dev/null +++ b/tests/test_glycans.py @@ -0,0 +1,44 @@ +# Copyright (c) 2024 Chai Discovery, Inc. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for details. +from chai_lab.data.parsing.glycans import _glycan_string_to_sugars_and_bonds + + +def test_complex_parsing(): + glycan = "MAN(6-1 FUC)(4-1 MAN(6-1 MAN(6-1 MAN)))" + sugars, bonds = _glycan_string_to_sugars_and_bonds(glycan) + assert len(sugars) == 5 + + bond1, bond2, bond3, bond4 = bonds + + assert bond1.src_sugar_index == 0 + assert bond1.dst_sugar_index == 1 + assert bond2.src_sugar_index == 0 + assert bond2.dst_sugar_index == 2 + assert bond3.src_sugar_index == 2 + assert bond3.src_sugar_index == 2 + assert bond3.dst_sugar_index == 3 + assert bond4.src_sugar_index == 3 + assert bond4.dst_sugar_index == 4 + + +def test_complex_parsing_2(): + glycan = ( + "MAN(4-1 FUC(4-1 MAN)(6-1 FUC(4-1 MAN)))(6-1 MAN(6-1 MAN(4-1 MAN)(6-1 FUC)))" + ) + sugars, bonds = _glycan_string_to_sugars_and_bonds(glycan) + assert len(sugars) == 9 + + expected_bonds = [ + (0, 1), + (1, 2), + (1, 3), + (3, 4), + (0, 5), + (5, 6), + (6, 7), + (6, 8), + ] + for (expected_src, expected_dst), bond in zip(expected_bonds, bonds, strict=True): + assert bond.src_sugar_index == expected_src + assert bond.dst_sugar_index == expected_dst