diff --git a/chai_lab/chai1.py b/chai_lab/chai1.py index f25ac9b..fa0f264 100644 --- a/chai_lab/chai1.py +++ b/chai_lab/chai1.py @@ -902,9 +902,9 @@ def avg_per_token_1d(x): bfactors=scaled_plddt_scores_per_atom, output_batch=inputs, write_path=cif_out_path, - entity_names={ - c.entity_data.entity_id: c.entity_data.entity_name - for c in feature_context.chains + asym_entity_names={ + i: c.entity_data.entity_name + for i, c in enumerate(feature_context.chains, start=1) }, ) cif_paths.append(cif_out_path) diff --git a/chai_lab/data/io/cif_utils.py b/chai_lab/data/io/cif_utils.py index 5a65b97..98d65b9 100644 --- a/chai_lab/data/io/cif_utils.py +++ b/chai_lab/data/io/cif_utils.py @@ -72,7 +72,11 @@ def token_centre_plddts( return plddts[atom_idces].tolist(), residue_indices.tolist() -def get_chains_metadata(context: PDBContext, entity_names) -> dict[int, AsymUnit]: +def get_chains_metadata( + context: PDBContext, asymid2entity_name: dict[int, str] +) -> dict[int, AsymUnit]: + """Return mapping from asym id to AsymUnit objects.""" + assert context.asym_id2entity_type.keys() == asymid2entity_name.keys() # for each chain, get chain id, entity id, full sequence token_res_names = context.token_res_names_to_string @@ -101,15 +105,15 @@ def get_chains_metadata(context: PDBContext, entity_names) -> dict[int, AsymUnit sequence = [chain_token_res_names[i] for i in any_token_in_resi] - entity_id = context.token_entity_id[token_indices[0]] - chain_id_str = _get_chain_letter(asym_id) asym_id2asym_unit[asym_id] = AsymUnit( entity=Entity( # sequence is a list of ChemComponents for aminoacids/bases - sequence=[_to_chem_component(resi, entity_type) for resi in sequence], - description=entity_names[int(entity_id)], + sequence=[ + _to_chem_component(resi, entity_type, asym_id) for resi in sequence + ], + description=asymid2entity_name[asym_id], ), details=f"Chain {chain_id_str}", id=chain_id_str, @@ -118,11 +122,10 @@ def get_chains_metadata(context: PDBContext, entity_names) -> dict[int, AsymUnit return asym_id2asym_unit -def _to_chem_component(res_name_3: str, entity_type: int): +def _to_chem_component(res_name_3: str, entity_type: int, asym_id: int): match entity_type: case EntityType.LIGAND.value: - code = res_name_3 - return NonPolymerChemComp(res_name_3) + return NonPolymerChemComp(id=res_name_3 + str(asym_id)) case EntityType.MANUAL_GLYCAN.value: return SaccharideChemComp(id=res_name_3, name=res_name_3) case EntityType.PROTEIN.value: @@ -140,11 +143,12 @@ def _to_chem_component(res_name_3: str, entity_type: int): raise NotImplementedError(f"Cannot handle entity type: {entity_type}") +@typecheck def save_to_cif( coords: Float[Tensor, "1 n_atoms 3"], output_batch: dict, write_path: Path, - entity_names: dict[int, str], + asym_entity_names: dict[int, str], bfactors: Float[Tensor, "1 n_atoms"] | None = None, ): write_path.parent.mkdir(parents=True, exist_ok=True) @@ -153,7 +157,7 @@ def save_to_cif( coords=rearrange(coords, "1 n c -> n c", c=3).cpu(), plddts=None if bfactors is None else rearrange(bfactors, "1 n -> n").cpu(), context=pdb_context_from_batch(output_batch), - entity_names=entity_names, + asym_entity_names=asym_entity_names, out_path=write_path, ) logger.info(f"saved cif file to {write_path}") @@ -164,10 +168,12 @@ def new_context_to_cif_atoms( coords: Float[Tensor, "n_atoms 3"], plddts: Float[Tensor, "n_atoms"] | None, context: PDBContext, - entity_names: dict[int, str], + asym_entity_names: dict[int, str], out_path: Path, ): - asym_id2asym_unit = get_chains_metadata(context, entity_names=entity_names) + asym_id2asym_unit = get_chains_metadata( + context, asymid2entity_name=asym_entity_names + ) atom_asym_id = context.token_asym_id[context.atom_token_index] # atom level attributes