From 39aec6d9d2c113452d66eea66a87079a7e6f6764 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Wed, 4 Dec 2024 22:33:14 +0000 Subject: [PATCH 1/6] Typo in logging message --- chai_lab/data/dataset/msas/preprocess.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/chai_lab/data/dataset/msas/preprocess.py b/chai_lab/data/dataset/msas/preprocess.py index efead11..f2dc8b9 100644 --- a/chai_lab/data/dataset/msas/preprocess.py +++ b/chai_lab/data/dataset/msas/preprocess.py @@ -17,6 +17,7 @@ _UKEY_FOR_QUERY = (-999, -999) logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) def merge_main_msas_by_chain(msas: list[MSAContext]) -> MSAContext: @@ -120,7 +121,7 @@ def pair_and_merge_msas(msas: list[MSAContext]) -> MSAContext: selected_msa = msa.take_rows_with_padding(all_rowids) logger.info( - f"Loaded (paired in includes query sequence): " + f"Loaded (paired includes query sequence): " f"{n_paired_msa=} {n_unpaired_msa=} out of {msa.depth=} " ) From 52a99ba7898e90810451110b3f49308885624fd3 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Wed, 4 Dec 2024 22:40:29 +0000 Subject: [PATCH 2/6] Call colabfold mmseqs server once for paired MSAs, and once for unpaired MSAs --- chai_lab/data/dataset/msas/colabfold.py | 106 +++++++++++++++++++----- 1 file changed, 83 insertions(+), 23 deletions(-) diff --git a/chai_lab/data/dataset/msas/colabfold.py b/chai_lab/data/dataset/msas/colabfold.py index ffca923..4fbea8f 100644 --- a/chai_lab/data/dataset/msas/colabfold.py +++ b/chai_lab/data/dataset/msas/colabfold.py @@ -41,6 +41,7 @@ def _run_mmseqs2( host_url="https://api.colabfold.com", user_agent: str = "", ) -> list[str] | tuple[list[str], list[str]]: + """Return a block of a3m lines for each of the input sequences in x.""" submission_endpoint = "ticket/pair" if use_pairing else "ticket/msa" headers = {} @@ -342,19 +343,26 @@ def download(ID, path): return (a3m_lines, template_paths) if use_templates else a3m_lines +def _is_padding_msa_row(sequence: str) -> bool: + """Check if the given MSA sequence is a a padding sequence.""" + seq_chars = set(sequence) + return len(seq_chars) == 1 and seq_chars.pop() == "-" + + def generate_colabfold_msas( protein_seqs: list[str], msa_dir: Path, msa_server_url: str, + write_a3m_to_msa_dir: bool = False, # Useful for manual inspection + debugging ): """ Generate MSAs using the ColabFold (https://github.com/sokrypton/ColabFold) server. No-op if no protein sequences are given. - N.B. the MSAs in our technical report were generated using jackhmmer, not + N.B.: + - the MSAs in our technical report were generated using jackhmmer, not ColabFold, so we would expect some difference in results. - - This implementation also relies on ColabFold's chain pairing algorithm + - this implementation relies on ColabFold's chain pairing algorithm rather than using Chai-1's own algorithm, which could also lead to differences in results. @@ -369,52 +377,104 @@ def generate_colabfold_msas( with tempfile.TemporaryDirectory() as tmp_dir_path: tmp_dir = Path(tmp_dir_path) + mmseqs_paired_dir = tmp_dir / "mmseqs_paired" + mmseqs_paired_dir.mkdir() + mmseqs_dir = tmp_dir / "mmseqs" mmseqs_dir.mkdir() - a3ms_dir = tmp_dir / "a3ms" + a3ms_dir = (tmp_dir if not write_a3m_to_msa_dir else msa_dir) / "a3ms" a3ms_dir.mkdir() # Generate MSAs for each protein chain logger.info(f"Running MSA generation for {len(protein_seqs)} protein sequences") - msas = _run_mmseqs2( + + # In paired mode, mmseqs2 returns paired a3ms where all a3ms have the same number of rows + # and each row is already paired to have the same species. As such, we insert pairing key + # as the i-th index of the sequence so long as it isn't a padding sequence (all -) + if len(protein_seqs) > 1: + paired_msas = _run_mmseqs2( + protein_seqs, + mmseqs_paired_dir, + use_pairing=True, + host_url=msa_server_url, + user_agent=f"chai-lab/{__version__} feedback@chaidiscovery.com", + ) + else: + # If we only have a single protein chain, there are no paired MSAs by definition + paired_msas = ["" for _ in protein_seqs] + assert isinstance(paired_msas, list) + + # MSAs without pairing logic attached; may include sequences not contained in the paired MSA + # Needs a second call as the colabfold server returns either paired or unpaired, not both + per_chain_msas = _run_mmseqs2( protein_seqs, mmseqs_dir, - # N.B. we can set this to False to disable pairing - use_pairing=len(protein_seqs) > 1, + use_pairing=False, host_url=msa_server_url, user_agent=f"chai-lab/{__version__} feedback@chaidiscovery.com", ) - assert isinstance(msas, list) # Process the MSAs into our internal format - for protein_seq, msa in zip(protein_seqs, msas, strict=True): - # Write out an A3M file - a3m_path = a3ms_dir / f"{hash_sequence(protein_seq.upper())}.a3m" - a3m_path.write_text(msa) + for protein_seq, pair_msa, single_msa in zip( + protein_seqs, paired_msas, per_chain_msas, strict=True + ): + # Write out an A3M file for both + hkey = hash_sequence(protein_seq.upper()) + pair_a3m_path = a3ms_dir / f"{hkey}.pair.a3m" + pair_a3m_path.write_text(pair_msa) + single_a3m_path = a3ms_dir / f"{hkey}.single.a3m" + single_a3m_path.write_text(single_msa) + + ## Convert the A3M file into aligned parquet files + # Set the pairing key as the ith-index in the sequences, skip over sequences that have + # been inserted as padding as our internal pairing logic will match on pairing key. + paired_fasta: list[tuple[int, str, str]] = [ + (pairkey, record.header, record.sequence) + for pairkey, record in enumerate(read_fasta(pair_a3m_path)) + if not _is_padding_msa_row(record.sequence) + ] + pairing_key, paired_headers, paired_msa_seqs = ( + zip(*paired_fasta) if paired_fasta else ((), (), ()) + ) - # Convert the A3M file into aligned parquet files - msa_fasta = read_fasta(a3m_path) - headers, msa_seqs = zip(*msa_fasta) + # Non-paired MSA sequences that weren't already covered in the paired MSA; skip header + single_fasta: list[tuple[str, str]] = [ + (record.header, record.sequence) + for i, record in enumerate(read_fasta(single_a3m_path)) + if ( + i > 0 + and not _is_padding_msa_row(record.sequence) + and record.sequence not in set(paired_msa_seqs) + ) + ] + single_headers, single_msa_seqs = ( + zip(*single_fasta) if single_fasta else ((), ()) + ) + # Create null pairing keys for each of the entries in the single MSA seq + single_null_pair_keys = ["" for _ in range(len(single_msa_seqs))] # This shouldn't have much of an effect on the model, but we make # a best effort to synthesize a source database anyway + # NOTE we already dropped the query row from the single MSAs so no need to slice source_databases = ["query"] + [ "uniref90" if h.startswith("UniRef") else "bfd_uniclust" - for h in headers[1:] + for h in paired_headers[1:] + single_headers ] + # Combine information across paired and single hits + all_sequences = paired_msa_seqs + single_msa_seqs + all_pairing_keys = [str(k) for k in pairing_key] + single_null_pair_keys + assert ( + len(all_sequences) == len(all_pairing_keys) == len(source_databases) + ), f"Mismatched lengths: {len(all_sequences)=} {len(all_pairing_keys)=} {len(source_databases)=}" + # Map the MSAs to our internal format aligned_df = pd.DataFrame( data=dict( - sequence=msa_seqs, + sequence=all_sequences, source_database=source_databases, - # ColabFold does not return taxonomies from its API, so we - # can't rely on our internal chain pairing logic. As an - # alternative, we could disable ColabFold pairing and rely - # on a mapping from sequence ~> taxonomy, which would allow - # us to use our internal pairing logic. - pairing_key="", + pairing_key=all_pairing_keys, comment="", ), ) From 630566f90a8c3dfa2bb2b39a84264fac1946dd99 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Wed, 4 Dec 2024 22:46:31 +0000 Subject: [PATCH 3/6] Update comment --- chai_lab/data/dataset/msas/colabfold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chai_lab/data/dataset/msas/colabfold.py b/chai_lab/data/dataset/msas/colabfold.py index 4fbea8f..24b5e74 100644 --- a/chai_lab/data/dataset/msas/colabfold.py +++ b/chai_lab/data/dataset/msas/colabfold.py @@ -26,7 +26,7 @@ ) -# N.B. this code is copied from https://github.com/sokrypton/ColabFold +# N.B. this function (and this function only) is copied from https://github.com/sokrypton/ColabFold # and follows the license in that repository @typing.no_type_check # Original ColabFold code was not well typed def _run_mmseqs2( From 0b65b932e82197f5e59a6608200524eef582c0fe Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Wed, 4 Dec 2024 22:54:20 +0000 Subject: [PATCH 4/6] Fix issue with monomers and constructing source databases --- chai_lab/data/dataset/msas/colabfold.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chai_lab/data/dataset/msas/colabfold.py b/chai_lab/data/dataset/msas/colabfold.py index 24b5e74..872aef9 100644 --- a/chai_lab/data/dataset/msas/colabfold.py +++ b/chai_lab/data/dataset/msas/colabfold.py @@ -459,7 +459,7 @@ def generate_colabfold_msas( # NOTE we already dropped the query row from the single MSAs so no need to slice source_databases = ["query"] + [ "uniref90" if h.startswith("UniRef") else "bfd_uniclust" - for h in paired_headers[1:] + single_headers + for h in (paired_headers + single_headers)[1:] ] # Combine information across paired and single hits From 0073bea317d0f564c465c4255efa7501199643d8 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Wed, 4 Dec 2024 23:00:44 +0000 Subject: [PATCH 5/6] Avoid using string literals --- chai_lab/data/dataset/msas/colabfold.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/chai_lab/data/dataset/msas/colabfold.py b/chai_lab/data/dataset/msas/colabfold.py index 872aef9..8e6688b 100644 --- a/chai_lab/data/dataset/msas/colabfold.py +++ b/chai_lab/data/dataset/msas/colabfold.py @@ -18,6 +18,7 @@ from chai_lab import __version__ from chai_lab.data.parsing.fasta import read_fasta from chai_lab.data.parsing.msas.aligned_pqt import expected_basename, hash_sequence +from chai_lab.data.parsing.msas.data_source import MSADataSource logger = logging.getLogger(__name__) @@ -458,7 +459,11 @@ def generate_colabfold_msas( # a best effort to synthesize a source database anyway # NOTE we already dropped the query row from the single MSAs so no need to slice source_databases = ["query"] + [ - "uniref90" if h.startswith("UniRef") else "bfd_uniclust" + ( + MSADataSource.UNIREF90.value + if h.startswith("UniRef") + else MSADataSource.BFD_UNICLUST.value + ) for h in (paired_headers + single_headers)[1:] ] From 3bf3246113b5a4f4ee1119cd579cf969a50f9d67 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Thu, 5 Dec 2024 01:16:10 +0000 Subject: [PATCH 6/6] Address PR feedback --- chai_lab/data/dataset/msas/colabfold.py | 30 ++++++++++++------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/chai_lab/data/dataset/msas/colabfold.py b/chai_lab/data/dataset/msas/colabfold.py index 8e6688b..c5c3142 100644 --- a/chai_lab/data/dataset/msas/colabfold.py +++ b/chai_lab/data/dataset/msas/colabfold.py @@ -16,7 +16,7 @@ from tqdm import tqdm from chai_lab import __version__ -from chai_lab.data.parsing.fasta import read_fasta +from chai_lab.data.parsing.fasta import Fasta, read_fasta from chai_lab.data.parsing.msas.aligned_pqt import expected_basename, hash_sequence from chai_lab.data.parsing.msas.data_source import MSADataSource @@ -393,6 +393,7 @@ def generate_colabfold_msas( # In paired mode, mmseqs2 returns paired a3ms where all a3ms have the same number of rows # and each row is already paired to have the same species. As such, we insert pairing key # as the i-th index of the sequence so long as it isn't a padding sequence (all -) + paired_msas: list[str] if len(protein_seqs) > 1: paired_msas = _run_mmseqs2( protein_seqs, @@ -403,8 +404,7 @@ def generate_colabfold_msas( ) else: # If we only have a single protein chain, there are no paired MSAs by definition - paired_msas = ["" for _ in protein_seqs] - assert isinstance(paired_msas, list) + paired_msas = [""] * len(protein_seqs) # MSAs without pairing logic attached; may include sequences not contained in the paired MSA # Needs a second call as the colabfold server returns either paired or unpaired, not both @@ -430,30 +430,30 @@ def generate_colabfold_msas( ## Convert the A3M file into aligned parquet files # Set the pairing key as the ith-index in the sequences, skip over sequences that have # been inserted as padding as our internal pairing logic will match on pairing key. - paired_fasta: list[tuple[int, str, str]] = [ - (pairkey, record.header, record.sequence) + paired_fasta: list[tuple[str, str, str]] = [ + (str(pairkey), record.header, record.sequence) for pairkey, record in enumerate(read_fasta(pair_a3m_path)) if not _is_padding_msa_row(record.sequence) ] pairing_key, paired_headers, paired_msa_seqs = ( zip(*paired_fasta) if paired_fasta else ((), (), ()) ) + unique_paired_msa_seqs = set(paired_msa_seqs) # Non-paired MSA sequences that weren't already covered in the paired MSA; skip header - single_fasta: list[tuple[str, str]] = [ - (record.header, record.sequence) + single_fasta: list[Fasta] = [ + record for i, record in enumerate(read_fasta(single_a3m_path)) if ( i > 0 and not _is_padding_msa_row(record.sequence) - and record.sequence not in set(paired_msa_seqs) + and record.sequence not in unique_paired_msa_seqs ) ] - single_headers, single_msa_seqs = ( - zip(*single_fasta) if single_fasta else ((), ()) - ) + single_headers = [record.header for record in single_fasta] + single_msa_seqs = [record.sequence for record in single_fasta] # Create null pairing keys for each of the entries in the single MSA seq - single_null_pair_keys = ["" for _ in range(len(single_msa_seqs))] + single_null_pair_keys = [""] * len(single_msa_seqs) # This shouldn't have much of an effect on the model, but we make # a best effort to synthesize a source database anyway @@ -464,12 +464,12 @@ def generate_colabfold_msas( if h.startswith("UniRef") else MSADataSource.BFD_UNICLUST.value ) - for h in (paired_headers + single_headers)[1:] + for h in (list(paired_headers) + single_headers)[1:] ] # Combine information across paired and single hits - all_sequences = paired_msa_seqs + single_msa_seqs - all_pairing_keys = [str(k) for k in pairing_key] + single_null_pair_keys + all_sequences = list(paired_msa_seqs) + single_msa_seqs + all_pairing_keys = list(pairing_key) + single_null_pair_keys assert ( len(all_sequences) == len(all_pairing_keys) == len(source_databases) ), f"Mismatched lengths: {len(all_sequences)=} {len(all_pairing_keys)=} {len(source_databases)=}"