diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index c630618..4ac2031 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -7,6 +7,7 @@ from collections import Counter from dataclasses import dataclass from pathlib import Path +from tempfile import TemporaryDirectory import numpy as np import torch @@ -313,12 +314,13 @@ def run_inference( for chain in chains if chain.entity_data.entity_type == EntityType.PROTEIN ] - msa_dir = output_dir / "mmseqs_msas" - msa_dir.mkdir(exist_ok=False, parents=True) - generate_colabfold_msas(protein_seqs=protein_sequences, msa_dir=msa_dir) - msa_context, msa_profile_context = get_msa_contexts( - chains, msa_directory=msa_dir - ) + # Save MSAs to temporary directory to ensure we never clobber anything + with TemporaryDirectory() as tmpdir: + msa_dir = Path(tmpdir) + generate_colabfold_msas(protein_seqs=protein_sequences, msa_dir=msa_dir) + msa_context, msa_profile_context = get_msa_contexts( + chains, msa_directory=msa_dir + ) elif msa_directory is not None: msa_context, msa_profile_context = get_msa_contexts( chains, msa_directory=msa_directory