Skip to content

Commit

Permalink
Better fault tolerance
Browse files Browse the repository at this point in the history
  • Loading branch information
wukevin committed Feb 19, 2025
1 parent 9bae054 commit 1621c73
Showing 1 changed file with 32 additions and 18 deletions.
50 changes: 32 additions & 18 deletions chai_lab/data/dataset/templates/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,10 @@ def __post_init__(self):
def query_identifier(self) -> str:
    """Return the query PDB ID stored on the underlying template hit."""
    return self.template_hit.query_pdb_id

@property
def hit_identifier(self) -> str:
    """Return the hit's identifier, formatted as ``"<pdb_id>|<chain_id>"``.

    Both parts are read from ``self.template_hit``.
    """
    return f"{self.template_hit.pdb_id}|{self.template_hit.chain_id}"

@property
def template_hit_indices(self) -> Int[Tensor, "n_tokens"]:
"""Indices of hit within full hit sequence."""
Expand Down Expand Up @@ -356,25 +360,32 @@ def get_template_data(
# Check that the loaded version based on an AllAtomStructureContext matches what
# we expect from the "raw" TemplateHit
if strict_subsequence_check:
original_alignment = "".join(
[
c
for c in template_hit.query_seq_realigned
if c.isupper() or not c.isalpha() # Keep upper and "-"
]
)
context_alignment = "".join(
[
rc.residue_types_with_nucleotides[i]
for i in template.template_restype.tolist()
]
)
# NOTE before these are all loaded into a TemplateContext, they do not
# contain any flanking gap chars "-" so we do not check that the flanking
# gaps are equal.
if context_alignment not in original_alignment:
try:
original_alignment = "".join(
[
c
for c in template_hit.query_seq_realigned
if c.isupper() or not c.isalpha() # Keep upper and "-"
]
)
context_alignment = "".join(
[
rc.residue_types_with_nucleotides[i]
for i in template.template_restype.tolist()
]
)
# NOTE before these are all loaded into a TemplateContext, they do not
# contain any flanking gap chars "-" so we do not check that the flanking
# gaps are equal.
if context_alignment not in original_alignment:
logger.warning(
f"Skipping {template_hit} due to mismatched sequences: {context_alignment=} {original_alignment=}"
)
continue
except Exception:
logger.warning(
f"Mismatched sequences loading {template_hit}: {context_alignment=} {original_alignment=}"
f"Skipping {template_hit} due to exception when checking sequences",
exc_info=True,
)
continue

Expand All @@ -393,4 +404,7 @@ def get_template_data(
f"Templates for {pdb_id} | {len(template_data)} remain, dropped hits: {drop_count}"
)

logger.info(
f"Loaded {len(template_data)} templates: {[t.hit_identifier for t in template_data]}"
)
return template_data

0 comments on commit 1621c73

Please sign in to comment.