Skip to content

Commit

Permalink
Script to take ColabFold outputs and run Chai1 with the same MSAs and…
Browse files Browse the repository at this point in the history
… templates (#310)
  • Loading branch information
wukevin authored Feb 18, 2025
1 parent 172b666 commit d591369
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 3 deletions.
7 changes: 5 additions & 2 deletions chai_lab/data/parsing/fasta.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# See the LICENSE file for details.

import logging
from io import StringIO
from pathlib import Path
from typing import NamedTuple, Sequence

Expand Down Expand Up @@ -30,10 +31,12 @@ def fastas_to_str(fastas: Sequence[Fasta]) -> str:
return "".join(f">{fasta.header}\n{fasta.sequence}\n" for fasta in fastas)


def read_fasta(file_path: str | Path | StringIO) -> list[Fasta]:
    """Parse FASTA records from a file path or an in-memory text buffer.

    Args:
        file_path: Path to a FASTA file on disk, or a StringIO already
            containing FASTA-formatted text.

    Returns:
        One Fasta per record, pairing the full description line with the
        sequence string.
    """
    from Bio import SeqIO

    def _parse(handle) -> list[Fasta]:
        # SeqIO accepts any open text handle, so path and StringIO share this.
        return [Fasta(rec.description, str(rec.seq)) for rec in SeqIO.parse(handle, "fasta")]

    if isinstance(file_path, (str, Path)):
        # Use a context manager so the file handle is closed deterministically
        # (previously the handle from open() was left for the GC to reclaim).
        with open(file_path) as fh:
            return _parse(fh)
    return _parse(file_path)


Expand Down
22 changes: 22 additions & 0 deletions chai_lab/data/parsing/msas/a3m.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,17 @@
we do use expects a column of sequences in the same format that .a3m gives.
"""

import re
import string
from functools import lru_cache
from io import StringIO
from pathlib import Path
from typing import Final

import numba
import numpy as np

from chai_lab.data.parsing.fasta import Fasta, read_fasta
from chai_lab.data.residue_constants import residue_types_with_nucleotides_order

MAPPED_TOKEN_SKIP: Final[int] = -1
Expand Down Expand Up @@ -107,3 +111,21 @@ def tokenize_sequences_to_arrays(
out_deletions=out_deletions,
)
return out_sequences, out_deletions


def read_colabfold_a3m(fname: Path) -> dict[str, list[Fasta]]:
    """Returns mapping of MSA hits per identifier in the given a3m file.
    The query line in each block of MSA hits is retained.
    """
    mapping: dict[str, list[Fasta]] = {}
    # ColabFold concatenates per-chain MSA blocks, delimited by null bytes.
    for chunk in fname.read_text().split("\x00"):
        if not chunk:
            continue
        entries = read_fasta(StringIO(chunk))
        assert len(entries) > 0
        # The first header in each block is ColabFold's three-digit chain id.
        query_id = entries[0].header
        assert re.match(r"^[0-9]{3}$", query_id)
        mapping[query_id] = entries
    return mapping
2 changes: 1 addition & 1 deletion chai_lab/model/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def get_asym_id_from_subchain_id(
subchain_id: str,
source_pdb_chain_id: UInt8[Tensor, "n_tokens 4"],
token_asym_id: Int[Tensor, "n"],
):
) -> int:
# encode the subchain ids and perform lookup in context features
chain_id_tensorcode = string_to_tensorcode(subchain_id, pad_to_length=4)
chain_id_tensorcode = chain_id_tensorcode.to(token_asym_id.device)
Expand Down
194 changes: 194 additions & 0 deletions scripts/stage_for_chai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# Copyright (c) 2024 Chai Discovery, Inc.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for details.
"""
Given an output directory from a ColabFold run, traverses the directory structure and stages
the same MSAs and templates to run through Chai1.
Some minimal example:
Given the following directory structure:
colab_out_dir/
- 4nnp_env/
- 4nnp_pairgreedy/
...
- sequences.csv (containing 4nnp as an id)
Run:
python stage_for_chai.py colab_out_dir chai_folder
This should create the following:
chai_folder/
- 4nnp/
- chai.fasta (input sequences for chai model)
- msas/ (contain the same sequences + pairing as colabfold writes)
- hash1.aligned.pqt
- hash2.aligned.pqt
- ...
all_template_hits.m8 (contains template hits for all chains)
Then, you can run Chai on the files:
chai-lab fold chai_folder/4nnp/chai.fasta 4nnp_out --msa-directory chai_folder/4nnp/msas/ --template-hits-path chai_folder/4nnp/all_template_hits.m8
NOTE This preserves the pairing that ColabFold determines; this is NOT necessarily
the same as the pairing that occurs when using the --use-msa-server flag.
"""

import logging
from pathlib import Path

import pandas as pd
import typer

from chai_lab.data.io.cif_utils import get_chain_letter
from chai_lab.data.parsing.fasta import Fasta, write_fastas
from chai_lab.data.parsing.msas.a3m import read_colabfold_a3m
from chai_lab.data.parsing.msas.aligned_pqt import (
AlignedParquetModel,
expected_basename,
)
from chai_lab.data.parsing.msas.data_source import MSADataSource
from chai_lab.data.parsing.templates.m8 import parse_m8_file

app = typer.Typer(pretty_exceptions_enable=False)


def read_colabfold_inputs(fname: Path) -> dict[str, list[Fasta]]:
    """Extracts sequences from colabfold input table.

    Args:
        fname: CSV with exactly two columns, "id" and "sequence". A multi-chain
            complex is given as a single ":"-delimited sequence string.

    Returns:
        Mapping of each input id to one protein Fasta per chain, with headers
        of the form "protein|<chain letter>" assigned in chain order.
    """
    df = pd.read_csv(fname, delimiter=",")
    assert list(df.columns) == ["id", "sequence"]
    retval: dict[str, list[Fasta]] = {}
    for row in df.itertuples():
        sequences: list[str] = row.sequence.split(":")  # type: ignore
        # Renamed from `complex`, which shadowed the builtin type.
        chain_fastas: list[Fasta] = [
            Fasta(header=f"protein|{get_chain_letter(i)}", sequence=seq)
            for i, seq in enumerate(sequences, start=1)
        ]
        retval[row.id] = chain_fastas  # type: ignore
    return retval


def gather_colabfold_msas(
    colabfold_out_dir: Path, identifier: str, output_folder: Path
) -> dict[str, str]:
    """Gathers MSAs generated by colabfold and writes them to the given output folder.
    Returns mapping of colabfold generated identifiers -> sequences.

    For each chain, combines three ColabFold a3m outputs -- the paired MSA
    (pair.a3m), the uniref MSA, and the environmental MSA -- into one table
    and writes it as an .aligned.pqt parquet file under output_folder.
    """
    output_folder.mkdir(parents=True, exist_ok=True)
    paired_msa = read_colabfold_a3m(
        colabfold_out_dir / f"{identifier}_pairgreedy/pair.a3m"
    )
    # The paired MSA should be the same number of rows for all
    paired_lengths = set(len(v) for v in paired_msa.values())
    assert len(paired_lengths) == 1
    n_paired = paired_lengths.pop()
    logging.info(f"[{identifier}] Colabfold paired {n_paired} MSAs")

    # Read in also the single chain MSAs
    uniref_msa = read_colabfold_a3m(colabfold_out_dir / f"{identifier}_env/uniref.a3m")

    env_msa = read_colabfold_a3m(
        colabfold_out_dir / f"{identifier}_env/bfd.mgnify30.metaeuk30.smag30.a3m"
    )
    # All three files must cover exactly the same set of chain identifiers
    assert set(uniref_msa.keys()) == set(env_msa.keys()) == set(paired_msa.keys())

    retval: dict[str, str] = {}
    for query in paired_msa.keys():
        # First entry of each a3m block is the query sequence itself
        query_seq = uniref_msa[query][0].sequence
        msa_rows = []
        # Paired block: the row index i is used as the pairing key, so the i-th
        # row of every chain carries the same key; the query row (i == 0) is
        # tagged as QUERY with an empty key.
        for i, row in enumerate(paired_msa[query]):
            record = {
                "sequence": row.sequence,
                "source_database": (
                    MSADataSource.QUERY if i == 0 else MSADataSource.UNIREF90
                ).value,
                "pairing_key": str(i) if i > 0 else "",
                "comment": "null",
            }
            msa_rows.append(record)
        # Unpaired uniref hits; [1:] skips the query row already emitted above
        for row in uniref_msa[query][1:]:
            msa_rows.append(
                {
                    "sequence": row.sequence,
                    "source_database": MSADataSource.UNIREF90.value,
                    "pairing_key": "",
                    "comment": "null",
                }
            )
        # Unpaired environmental hits, likewise without their query row
        for row in env_msa[query][1:]:
            msa_rows.append(
                {
                    "sequence": row.sequence,
                    "source_database": MSADataSource.BFD_UNICLUST.value,
                    "pairing_key": "",
                    "comment": "null",
                }
            )
        table = pd.DataFrame.from_records(msa_rows)
        # Validate against the aligned-parquet schema before writing
        AlignedParquetModel.validate(table)
        # expected_basename derives the filename from the query sequence,
        # presumably so chai-lab can locate each chain's MSA at fold time
        table.to_parquet(output_folder / expected_basename(query_sequence=query_seq))
        retval[query] = query_seq
    return retval


def gather_colabfold_templates(
    colabfold_out_dir: Path,
    identifier: str,
    chain_id_mapping: dict[str, str],
    output_folder: Path,
) -> Path:
    """Rewrites the ColabFold template hits (pdb70.m8) with Chai chain ids.

    Each hit's query_id is mapped through chain_id_mapping, and the result is
    written as a header-less, tab-separated .m8 file. Returns the path written.
    """
    m8_path = colabfold_out_dir / f"{identifier}_env" / "pdb70.m8"
    assert m8_path.is_file()
    hits = parse_m8_file(m8_path)
    hits["query_id"] = hits["query_id"].map(lambda q: chain_id_mapping[str(q)])
    out_path = output_folder / "all_template_hits.m8"
    hits.to_csv(out_path, sep="\t", index=False, header=False)
    return out_path


@app.command()
def main(colabfold_out_dir: Path, chai_dir: Path):
    """Takes a directory containing colabfold outputs and stages them for Chai1."""
    csv_files = list(colabfold_out_dir.glob("*.csv"))
    assert len(csv_files) == 1, f"Expected a single csv file but got {len(csv_files)}"
    fasta_entries: dict[str, list[Fasta]] = read_colabfold_inputs(csv_files.pop())

    for identifier, chains in fasta_entries.items():
        staging_dir = chai_dir / identifier
        staging_dir.mkdir(parents=True, exist_ok=True)

        # Stage the MSAs as .aligned.pqt files under msas/
        colab_seq_by_id = gather_colabfold_msas(
            colabfold_out_dir=colabfold_out_dir,
            identifier=identifier,
            output_folder=staging_dir / "msas",
        )
        # Every MSA query sequence must correspond to an input chain and vice versa
        assert set(colab_seq_by_id.values()) == set([f.sequence for f in chains])

        # Build a mapping from each ColabFold-generated id to the Chai chain id
        # (the part after "|" in the "protein|X" header). NOTE(review): when
        # several chains share a sequence, one matching chain id is used.
        chain_id_by_colab_id = {}
        for colab_id, seq in colab_seq_by_id.items():
            matches = [c for c in chains if c.sequence == seq]
            assert len(matches)
            chain_id_by_colab_id[colab_id] = matches.pop().header.split(
                "|", maxsplit=1
            )[-1]

        # Stage the template hits alongside the MSAs
        gather_colabfold_templates(
            colabfold_out_dir=colabfold_out_dir,
            identifier=identifier,
            chain_id_mapping=chain_id_by_colab_id,
            output_folder=staging_dir,
        )

        # Write the actual fasta input file
        write_fastas(chains, (staging_dir / "chai.fasta").as_posix())


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    app()

0 comments on commit d591369

Please sign in to comment.