Staging script
wukevin committed Feb 17, 2025
1 parent a83b639 commit ca77e0e
Showing 4 changed files with 173 additions and 7 deletions.
7 changes: 5 additions & 2 deletions chai_lab/data/parsing/fasta.py
@@ -3,6 +3,7 @@
 # See the LICENSE file for details.

 import logging
+from io import StringIO
 from pathlib import Path
 from typing import NamedTuple, Sequence

@@ -30,10 +31,12 @@ def fastas_to_str(fastas: Sequence[Fasta]) -> str:
     return "".join(f">{fasta.header}\n{fasta.sequence}\n" for fasta in fastas)


-def read_fasta(file_path: str | Path) -> list[Fasta]:
+def read_fasta(file_path: str | Path | StringIO) -> list[Fasta]:
     from Bio import SeqIO

-    fasta_sequences = SeqIO.parse(open(file_path), "fasta")
+    fasta_sequences = SeqIO.parse(
+        open(file_path) if isinstance(file_path, (str, Path)) else file_path, "fasta"
+    )
     return [Fasta(fasta.description, str(fasta.seq)) for fasta in fasta_sequences]
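With this change, read_fasta accepts an in-memory buffer in addition to a filesystem path. A minimal sketch of both call styles (the filename and sequence are illustrative):

from io import StringIO

from chai_lab.data.parsing.fasta import read_fasta

fastas = read_fasta("example.fasta")  # str / Path inputs are opened on the fly
fastas = read_fasta(StringIO(">101\nMKTAYI\n"))  # buffers are parsed directly
assert fastas[0].header == "101" and fastas[0].sequence == "MKTAYI"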


9 changes: 5 additions & 4 deletions chai_lab/data/parsing/msas/a3m.py
@@ -116,15 +116,16 @@ def tokenize_sequences_to_arrays(
 def read_colabfold_a3m(fname: Path) -> dict[str, list[Fasta]]:
     """Returns mapping of MSA hits per identifier in the given a3m file.
-    The query line in each block of MSA hits is dropped.
+    The query line in each block of MSA hits is retained.
     """
     text = fname.read_text()
     retval: dict[str, list[Fasta]] = {}
     for block in text.split("\x00"):  # Splits on null byte
         if not block:
             continue
         strio = StringIO(block)
-        query, *hits = read_fasta(strio)
-        assert re.match(r"^[0-9]{3}$", query.header)
-        retval[query.header] = hits
+        hits = read_fasta(strio)
+        assert len(hits) > 0
+        assert re.match(r"^[0-9]{3}$", (query := hits[0].header))
+        retval[query] = hits
     return retval
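For context, this parser assumes colabfold's concatenated a3m layout: one FASTA block per chain, separated by null bytes, each headed by a three-digit numeric query id (per the assertion above). A minimal sketch of consuming such a file, with illustrative contents and a hypothetical filename:

from pathlib import Path

from chai_lab.data.parsing.msas.a3m import read_colabfold_a3m

# Two blocks separated by a null byte, each starting with its numeric query
Path("pair.a3m").write_text(
    ">101\nMKTAYI\n>hit_a\nMKTCYI\n\x00>102\nGGSEQW\n>hit_b\nGGSEAW\n"
)
msas = read_colabfold_a3m(Path("pair.a3m"))
assert list(msas) == ["101", "102"]
assert msas["101"][0].sequence == "MKTAYI"  # the query row is now retained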
2 changes: 1 addition & 1 deletion chai_lab/model/utils.py
@@ -199,7 +199,7 @@ def get_asym_id_from_subchain_id(
     subchain_id: str,
     source_pdb_chain_id: UInt8[Tensor, "n_tokens 4"],
     token_asym_id: Int[Tensor, "n"],
-):
+) -> int:
     # encode the subchain ids and perform lookup in context features
     chain_id_tensorcode = string_to_tensorcode(subchain_id, pad_to_length=4)
     chain_id_tensorcode = chain_id_tensorcode.to(token_asym_id.device)
162 changes: 162 additions & 0 deletions scripts/stage_for_chai.py
@@ -0,0 +1,162 @@
# Copyright (c) 2024 Chai Discovery, Inc.
# Licensed under the Apache License, Version 2.0.
# See the LICENSE file for details.
"""
Stage colabfold outputs as input folders for Chai-1.
"""

import logging
from pathlib import Path

import pandas as pd
import typer

from chai_lab.data.io.cif_utils import get_chain_letter
from chai_lab.data.parsing.fasta import Fasta, write_fastas
from chai_lab.data.parsing.msas.a3m import read_colabfold_a3m
from chai_lab.data.parsing.msas.aligned_pqt import (
    AlignedParquetModel,
    expected_basename,
)
from chai_lab.data.parsing.msas.data_source import MSADataSource
from chai_lab.data.parsing.templates.m8 import parse_m8_file

app = typer.Typer(pretty_exceptions_enable=False)


def read_colabfold_inputs(fname: Path) -> dict[str, list[Fasta]]:
    """Extracts sequences from the colabfold input table."""
    df = pd.read_csv(fname, delimiter=",")
    assert list(df.columns) == ["id", "sequence"]
    retval: dict[str, list[Fasta]] = {}
    for row in df.itertuples():
        sequences: list[str] = row.sequence.split(":")  # type: ignore
        complex: list[Fasta] = [
            Fasta(header=f"protein|{get_chain_letter(i)}", sequence=seq)
            for i, seq in enumerate(sequences, start=1)
        ]
        retval[row.id] = complex  # type: ignore
    return retval
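
# For illustration, a colabfold input table (hypothetical contents) looks like:
#
#   id,sequence
#   complex1,MKTAYI:GGSEQW
#
# which this function maps to (assuming get_chain_letter(1) == "A", etc.):
#   {"complex1": [Fasta("protein|A", "MKTAYI"), Fasta("protein|B", "GGSEQW")]}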


def gather_colabfold_msas(
    colabfold_out_dir: Path, identifier: str, output_folder: Path
) -> dict[str, str]:
    """Gathers MSAs generated by colabfold and writes them to the given output folder.
    Returns a mapping of colabfold-generated identifiers -> query sequences.
    """
    output_folder.mkdir(parents=True, exist_ok=True)
    paired_msa = read_colabfold_a3m(
        colabfold_out_dir / f"{identifier}_pairgreedy/pair.a3m"
    )
    # The paired MSA should have the same number of rows for every chain
    paired_lengths = set(len(v) for v in paired_msa.values())
    assert len(paired_lengths) == 1
    n_paired = paired_lengths.pop()
    logging.info(f"[{identifier}] Colabfold paired {n_paired} MSA rows")

    # Also read in the single-chain MSAs
    uniref_msa = read_colabfold_a3m(colabfold_out_dir / f"{identifier}_env/uniref.a3m")

    env_msa = read_colabfold_a3m(
        colabfold_out_dir / f"{identifier}_env/bfd.mgnify30.metaeuk30.smag30.a3m"
    )
    assert set(uniref_msa.keys()) == set(env_msa.keys()) == set(paired_msa.keys())

    retval: dict[str, str] = {}
    for query in paired_msa.keys():
        query_seq = uniref_msa[query][0].sequence
        msa_rows = []
        for i, row in enumerate(paired_msa[query]):
            record = {
                "sequence": row.sequence,
                "source_database": (
                    MSADataSource.QUERY if i == 0 else MSADataSource.UNIREF90
                ).value,
                "pairing_key": str(i) if i > 0 else "",
                "comment": "null",
            }
            msa_rows.append(record)
        for row in uniref_msa[query][1:]:
            msa_rows.append(
                {
                    "sequence": row.sequence,
                    "source_database": MSADataSource.UNIREF90.value,
                    "pairing_key": "",
                    "comment": "null",
                }
            )
        for row in env_msa[query][1:]:
            msa_rows.append(
                {
                    "sequence": row.sequence,
                    "source_database": MSADataSource.BFD_UNICLUST.value,
                    "pairing_key": "",
                    "comment": "null",
                }
            )
        table = pd.DataFrame.from_records(msa_rows)
        AlignedParquetModel.validate(table)
        table.to_parquet(output_folder / expected_basename(query_sequence=query_seq))
        retval[query] = query_seq
    return retval
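
# Sketch of one written table (values illustrative; the actual source_database
# strings come from MSADataSource.*.value):
#
#   sequence  source_database  pairing_key  comment
#   MKTAYI    query            ""           null    <- query row from pair.a3m
#   MKTCYI    uniref90         "1"          null    <- paired hit, key = row index
#   MKTVYI    uniref90         ""           null    <- unpaired uniref hit
#   MKSAYI    bfd_uniclust     ""           null    <- environmental hit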


def gather_colabfold_templates(
    colabfold_out_dir: Path,
    identifier: str,
    chain_id_mapping: dict[str, str],
    output_folder: Path,
) -> Path:
    template_file = colabfold_out_dir / f"{identifier}_env" / "pdb70.m8"
    assert template_file.is_file()
    templates = parse_m8_file(template_file)
    templates["query_id"] = templates["query_id"].apply(
        lambda s: chain_id_mapping[str(s)]
    )
    outfile = output_folder / "all_template_hits.m8"
    templates.to_csv(outfile, sep="\t", index=False, header=False)
    return outfile
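
# Note: .m8 is the tab-separated BLAST/MMseqs2 tabular hit format. The rewrite
# above swaps colabfold's numeric query ids (e.g. "101") for Chai chain letters
# (e.g. "A") via chain_id_mapping, and header=False / index=False keeps the
# re-serialized file in plain m8 layout.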


@app.command()
def main(colabfold_out_dir: Path, chai_dir: Path):
    """Takes a directory containing colabfold outputs and stages them for Chai-1."""
    csv_files = list(colabfold_out_dir.glob("*.csv"))
    assert len(csv_files) == 1, f"Expected a single csv file but got {len(csv_files)}"
    fasta_entries: dict[str, list[Fasta]] = read_colabfold_inputs(csv_files.pop())

    for identifier, sequences in fasta_entries.items():
        chai_out_folder = chai_dir / identifier
        chai_out_folder.mkdir(parents=True, exist_ok=True)
        colabfold_id_to_seq = gather_colabfold_msas(
            colabfold_out_dir=colabfold_out_dir,
            identifier=identifier,
            output_folder=chai_out_folder / "msas",
        )
        assert set(colabfold_id_to_seq.values()) == set([f.sequence for f in sequences])

        # Build a mapping from each colabfold identifier to the matching Chai chain ID
        colab_id_to_chai_id = {}
        for colabfold_id, seq in colabfold_id_to_seq.items():
            chai_seq_matches = [s for s in sequences if s.sequence == seq]
            assert len(chai_seq_matches)
            colab_id_to_chai_id[colabfold_id] = chai_seq_matches.pop().header.split(
                "|", maxsplit=1
            )[-1]

        gather_colabfold_templates(
            colabfold_out_dir=colabfold_out_dir,
            identifier=identifier,
            chain_id_mapping=colab_id_to_chai_id,
            output_folder=chai_out_folder,
        )

        # Write the actual fasta input file
        write_fastas(sequences, (chai_out_folder / "chai.fasta").as_posix())


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    app()
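
Assuming colabfold was run with pairing and template search enabled, the staged layout per input id would look roughly like this (a sketch inferred from the paths in the script, not an authoritative spec):

chai_dir/
└── <id>/
    ├── msas/
    │   └── <per-chain>.aligned.pqt   # one per unique chain sequence, named via expected_basename
    ├── all_template_hits.m8
    └── chai.fasta

The entry point is the typer CLI above, e.g. python scripts/stage_for_chai.py <colabfold_out_dir> <chai_dir>.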
