diff --git a/chai_lab/data/dataset/msas/colabfold.py b/chai_lab/data/dataset/msas/colabfold.py index ffca923..c5c3142 100644 --- a/chai_lab/data/dataset/msas/colabfold.py +++ b/chai_lab/data/dataset/msas/colabfold.py @@ -16,8 +16,9 @@ from tqdm import tqdm from chai_lab import __version__ -from chai_lab.data.parsing.fasta import read_fasta +from chai_lab.data.parsing.fasta import Fasta, read_fasta from chai_lab.data.parsing.msas.aligned_pqt import expected_basename, hash_sequence +from chai_lab.data.parsing.msas.data_source import MSADataSource logger = logging.getLogger(__name__) @@ -26,7 +27,7 @@ ) -# N.B. this code is copied from https://github.com/sokrypton/ColabFold +# N.B. this function (and this function only) is copied from https://github.com/sokrypton/ColabFold # and follows the license in that repository @typing.no_type_check # Original ColabFold code was not well typed def _run_mmseqs2( @@ -41,6 +42,7 @@ def _run_mmseqs2( host_url="https://api.colabfold.com", user_agent: str = "", ) -> list[str] | tuple[list[str], list[str]]: + """Return a block of a3m lines for each of the input sequences in x.""" submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa" headers = {} @@ -342,19 +344,26 @@ def download(ID, path): return (a3m_lines, template_paths) if use_templates else a3m_lines +def _is_padding_msa_row(sequence: str) -> bool: + """Check if the given MSA sequence is a a padding sequence.""" + seq_chars = set(sequence) + return len(seq_chars) == 1 and seq_chars.pop() == "-" + + def generate_colabfold_msas( protein_seqs: list[str], msa_dir: Path, msa_server_url: str, + write_a3m_to_msa_dir: bool = False, # Useful for manual inspection + debugging ): """ Generate MSAs using the ColabFold (https://github.com/sokrypton/ColabFold) server. No-op if no protein sequences are given. - N.B. the MSAs in our technical report were generated using jackhmmer, not + N.B.: + - the MSAs in our technical report were generated using jackhmmer, not ColabFold, so we would expect some difference in results. - - This implementation also relies on ColabFold's chain pairing algorithm + - this implementation relies on ColabFold's chain pairing algorithm rather than using Chai-1's own algorithm, which could also lead to differences in results. @@ -369,52 +378,108 @@ def generate_colabfold_msas( with tempfile.TemporaryDirectory() as tmp_dir_path: tmp_dir = Path(tmp_dir_path) + mmseqs_paired_dir = tmp_dir / "mmseqs_paired" + mmseqs_paired_dir.mkdir() + mmseqs_dir = tmp_dir / "mmseqs" mmseqs_dir.mkdir() - a3ms_dir = tmp_dir / "a3ms" + a3ms_dir = (tmp_dir if not write_a3m_to_msa_dir else msa_dir) / "a3ms" a3ms_dir.mkdir() # Generate MSAs for each protein chain logger.info(f"Running MSA generation for {len(protein_seqs)} protein sequences") - msas = _run_mmseqs2( + + # In paired mode, mmseqs2 returns paired a3ms where all a3ms have the same number of rows + # and each row is already paired to have the same species. As such, we insert pairing key + # as the i-th index of the sequence so long as it isn't a padding sequence (all -) + paired_msas: list[str] + if len(protein_seqs) > 1: + paired_msas = _run_mmseqs2( + protein_seqs, + mmseqs_paired_dir, + use_pairing=True, + host_url=msa_server_url, + user_agent=f"chai-lab/{__version__} feedback@chaidiscovery.com", + ) + else: + # If we only have a single protein chain, there are no paired MSAs by definition + paired_msas = [""] * len(protein_seqs) + + # MSAs without pairing logic attached; may include sequences not contained in the paired MSA + # Needs a second call as the colabfold server returns either paired or unpaired, not both + per_chain_msas = _run_mmseqs2( protein_seqs, mmseqs_dir, - # N.B. we can set this to False to disable pairing - use_pairing=len(protein_seqs) > 1, + use_pairing=False, host_url=msa_server_url, user_agent=f"chai-lab/{__version__} feedback@chaidiscovery.com", ) - assert isinstance(msas, list) # Process the MSAs into our internal format - for protein_seq, msa in zip(protein_seqs, msas, strict=True): - # Write out an A3M file - a3m_path = a3ms_dir / f"{hash_sequence(protein_seq.upper())}.a3m" - a3m_path.write_text(msa) - - # Convert the A3M file into aligned parquet files - msa_fasta = read_fasta(a3m_path) - headers, msa_seqs = zip(*msa_fasta) + for protein_seq, pair_msa, single_msa in zip( + protein_seqs, paired_msas, per_chain_msas, strict=True + ): + # Write out an A3M file for both + hkey = hash_sequence(protein_seq.upper()) + pair_a3m_path = a3ms_dir / f"{hkey}.pair.a3m" + pair_a3m_path.write_text(pair_msa) + single_a3m_path = a3ms_dir / f"{hkey}.single.a3m" + single_a3m_path.write_text(single_msa) + + ## Convert the A3M file into aligned parquet files + # Set the pairing key as the ith-index in the sequences, skip over sequences that have + # been inserted as padding as our internal pairing logic will match on pairing key. + paired_fasta: list[tuple[str, str, str]] = [ + (str(pairkey), record.header, record.sequence) + for pairkey, record in enumerate(read_fasta(pair_a3m_path)) + if not _is_padding_msa_row(record.sequence) + ] + pairing_key, paired_headers, paired_msa_seqs = ( + zip(*paired_fasta) if paired_fasta else ((), (), ()) + ) + unique_paired_msa_seqs = set(paired_msa_seqs) + + # Non-paired MSA sequences that weren't already covered in the paired MSA; skip header + single_fasta: list[Fasta] = [ + record + for i, record in enumerate(read_fasta(single_a3m_path)) + if ( + i > 0 + and not _is_padding_msa_row(record.sequence) + and record.sequence not in unique_paired_msa_seqs + ) + ] + single_headers = [record.header for record in single_fasta] + single_msa_seqs = [record.sequence for record in single_fasta] + # Create null pairing keys for each of the entries in the single MSA seq + single_null_pair_keys = [""] * len(single_msa_seqs) # This shouldn't have much of an effect on the model, but we make # a best effort to synthesize a source database anyway + # NOTE we already dropped the query row from the single MSAs so no need to slice source_databases = ["query"] + [ - "uniref90" if h.startswith("UniRef") else "bfd_uniclust" - for h in headers[1:] + ( + MSADataSource.UNIREF90.value + if h.startswith("UniRef") + else MSADataSource.BFD_UNICLUST.value + ) + for h in (list(paired_headers) + single_headers)[1:] ] + # Combine information across paired and single hits + all_sequences = list(paired_msa_seqs) + single_msa_seqs + all_pairing_keys = list(pairing_key) + single_null_pair_keys + assert ( + len(all_sequences) == len(all_pairing_keys) == len(source_databases) + ), f"Mismatched lengths: {len(all_sequences)=} {len(all_pairing_keys)=} {len(source_databases)=}" + # Map the MSAs to our internal format aligned_df = pd.DataFrame( data=dict( - sequence=msa_seqs, + sequence=all_sequences, source_database=source_databases, - # ColabFold does not return taxonomies from its API, so we - # can't rely on our internal chain pairing logic. As an - # alternative, we could disable ColabFold pairing and rely - # on a mapping from sequence ~> taxonomy, which would allow - # us to use our internal pairing logic. - pairing_key="", + pairing_key=all_pairing_keys, comment="", ), ) diff --git a/chai_lab/data/dataset/msas/preprocess.py b/chai_lab/data/dataset/msas/preprocess.py index efead11..f2dc8b9 100644 --- a/chai_lab/data/dataset/msas/preprocess.py +++ b/chai_lab/data/dataset/msas/preprocess.py @@ -17,6 +17,7 @@ _UKEY_FOR_QUERY = (-999, -999) logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) def merge_main_msas_by_chain(msas: list[MSAContext]) -> MSAContext: @@ -120,7 +121,7 @@ def pair_and_merge_msas(msas: list[MSAContext]) -> MSAContext: selected_msa = msa.take_rows_with_padding(all_rowids) logger.info( - f"Loaded (paired in includes query sequence): " + f"Loaded (paired includes query sequence): " f"{n_paired_msa=} {n_unpaired_msa=} out of {msa.depth=} " )