diff --git a/chai_lab/data/dataset/msas/load.py b/chai_lab/data/dataset/msas/load.py index 5e177ed..7b29045 100644 --- a/chai_lab/data/dataset/msas/load.py +++ b/chai_lab/data/dataset/msas/load.py @@ -20,6 +20,7 @@ parse_aligned_pqt_to_msa_context, ) from chai_lab.data.parsing.msas.data_source import MSADataSource +from chai_lab.data.parsing.structure.entity_type import EntityType logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -43,14 +44,14 @@ def get_msa_contexts( # MSAs are constructed based on sequence, so use the unique sequences present # in input chains to determine the MSAs that need to be loaded - - def get_msa_contexts_for_seq(seq) -> MSAContext: + def get_msa_contexts_for_seq(seq: str, etype: EntityType) -> MSAContext: path = msa_directory / expected_basename(seq) - if not path.is_file(): - if seq != "X": - # Don't warn for the special "X" sequence + # If the MSA is missing, or the query is not a protein, return an empty MSA + if not path.is_file() or etype != EntityType.PROTEIN: + if etype == EntityType.PROTEIN: + # Warn for proteins that have missing MSAs logger.warning(f"No MSA found for sequence: {seq}") - [tokenized_seq] = tokenize_sequences_to_arrays([seq])[0] + [tokenized_seq], _ = tokenize_sequences_to_arrays([seq]) return MSAContext.create_single_seq( MSADataSource.QUERY, tokens=torch.from_numpy(tokenized_seq) ) @@ -61,9 +62,9 @@ def get_msa_contexts_for_seq(seq) -> MSAContext: # For each chain, either fetch the corresponding MSA or create an empty MSA if it is missing # + reindex to handle residues that are tokenized per-atom (this also crops if necessary) msa_contexts = [ - get_msa_contexts_for_seq(chain.entity_data.sequence)[ - :, chain.structure_context.token_residue_index - ] + get_msa_contexts_for_seq( + seq=chain.entity_data.sequence, etype=chain.entity_data.entity_type + )[:, chain.structure_context.token_residue_index] for chain in chains ] diff --git a/chai_lab/main.py b/chai_lab/main.py index 124a94d..8132d65 100644 --- a/chai_lab/main.py +++ b/chai_lab/main.py @@ -10,6 +10,8 @@ from chai_lab.chai1 import run_inference +logging.basicConfig(level=logging.INFO) + CITATION = """ @article{Chai-1-Technical-Report, title = {Chai-1: Decoding the molecular interactions of life}, @@ -38,5 +40,4 @@ def cli(): if __name__ == "__main__": - logging.basicConfig(level=logging.INFO) cli()