From 15e2b382a69c5f449281595ce40438c375f0c488 Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Mon, 9 Dec 2024 07:33:25 +0000 Subject: [PATCH] Expose helper function to merge a3ms in a directory --- chai_lab/data/parsing/msas/aligned_pqt.py | 15 +++++++++------ chai_lab/main.py | 5 +++++ 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/chai_lab/data/parsing/msas/aligned_pqt.py b/chai_lab/data/parsing/msas/aligned_pqt.py index a9a2cee..80a5679 100644 --- a/chai_lab/data/parsing/msas/aligned_pqt.py +++ b/chai_lab/data/parsing/msas/aligned_pqt.py @@ -9,7 +9,7 @@ import logging from functools import lru_cache from pathlib import Path -from typing import Literal, Mapping +from typing import Literal, Mapping, Optional import pandas as pd import pandera as pa @@ -174,7 +174,7 @@ def merge_multi_a3m_to_aligned_dataframe( msa_a3m_files: Mapping[Path, MSADataSource], insert_keys_for_sources: Literal["all", "none", "uniprot"] = "uniprot", ) -> pd.DataFrame: - """Merge multiple a3m files into a single aligned parquet file.""" + """Merge multiple a3ms from the same query sequence into a single aligned parquet.""" dfs = { src: a3m_to_aligned_dataframe( a3m_path, @@ -198,10 +198,10 @@ def merge_multi_a3m_to_aligned_dataframe( return pd.concat(chunks, ignore_index=True).reset_index(drop=True) -def _merge_files_in_directory(directory: str): +def merge_a3m_in_directory(directory: str, output_directory: Optional[str] = None): """Finds .a3m files in a directory and combine them into a single aligned.pqt file. Files are expected to be named like hits_uniref90.a3m (uniref90 is the source database). - All files in the directoroy are assumed to be derived from the same query sequence. + All files in the directory are assumed to be derived from the same query sequence. Provided as a example commandline interface to merge files. """ @@ -226,7 +226,10 @@ def _merge_files_in_directory(directory: str): ) # Get the query sequence and use it to determine where we save the file. query_seq: str = df.iloc[0]["sequence"] - df.to_parquet(dir_path / expected_basename(query_seq)) + # Default to writing into the same directory if output directory isn't specified + outdir = Path(output_directory) if output_directory is not None else dir_path + outdir.mkdir(exist_ok=True, parents=True) + df.to_parquet(outdir / expected_basename(query_seq)) if __name__ == "__main__": @@ -234,4 +237,4 @@ def _merge_files_in_directory(directory: str): logging.basicConfig(level=logging.INFO) - typer.run(_merge_files_in_directory) + typer.run(merge_a3m_in_directory) diff --git a/chai_lab/main.py b/chai_lab/main.py index 8132d65..7390b5a 100644 --- a/chai_lab/main.py +++ b/chai_lab/main.py @@ -9,6 +9,7 @@ import typer from chai_lab.chai1 import run_inference +from chai_lab.data.parsing.msas.aligned_pqt import merge_a3m_in_directory logging.basicConfig(level=logging.INFO) @@ -35,6 +36,10 @@ def citation(): def cli(): app = typer.Typer() app.command("fold", help="Run Chai-1 to fold a complex.")(run_inference) + app.command( + "a3m-to-pqt", + help="Convert a3m files for a *single sequence* into a aligned parquet file", + )(merge_a3m_in_directory) app.command("citation", help="Print citation information")(citation) app()