diff --git a/chai_lab/data/dataset/templates/load.py b/chai_lab/data/dataset/templates/load.py index e706e87..278ad50 100644 --- a/chai_lab/data/dataset/templates/load.py +++ b/chai_lab/data/dataset/templates/load.py @@ -92,6 +92,10 @@ def __post_init__(self): def query_identifier(self) -> str: return self.template_hit.query_pdb_id + @property + def hit_identifier(self) -> str: + return f"{self.template_hit.pdb_id}|{self.template_hit.chain_id}" + @property def template_hit_indices(self) -> Int[Tensor, "n_tokens"]: """Indices of hit within full hit sequence.""" @@ -356,25 +360,32 @@ def get_template_data( # Check that the loaded version based on an AllAtomStructureContext matches what # we expect from the "raw" TemplateHit if strict_subsequence_check: - original_alignment = "".join( - [ - c - for c in template_hit.query_seq_realigned - if c.isupper() or not c.isalpha() # Keep upper and "-" - ] - ) - context_alignment = "".join( - [ - rc.residue_types_with_nucleotides[i] - for i in template.template_restype.tolist() - ] - ) - # NOTE before these are all loaded into a TemplateContext, they do not - # contain any flanking gap chars "-" so we do not check that the flanking - # gaps are equal. - if context_alignment not in original_alignment: + try: + original_alignment = "".join( + [ + c + for c in template_hit.query_seq_realigned + if c.isupper() or not c.isalpha() # Keep upper and "-" + ] + ) + context_alignment = "".join( + [ + rc.residue_types_with_nucleotides[i] + for i in template.template_restype.tolist() + ] + ) + # NOTE before these are all loaded into a TemplateContext, they do not + # contain any flanking gap chars "-" so we do not check that the flanking + # gaps are equal. + if context_alignment not in original_alignment: + logger.warning( + f"Skipping {template_hit} due to mismatched sequences: {context_alignment=} {original_alignment=}" + ) + continue + except Exception: logger.warning( - f"Mismatched sequences loading {template_hit}: {context_alignment=} {original_alignment=}" + f"Skipping {template_hit} due to exception when checking sequences", + exc_info=True, ) continue @@ -393,4 +404,7 @@ def get_template_data( f"Templates for {pdb_id} | {len(template_data)} remain, dropped hits: {drop_count}" ) + logger.info( + f"Loaded {len(template_data)} templates: {[t.hit_identifier for t in template_data]}" + ) return template_data