Commit a2ac01b
Lots of tests... Some small bug fixes
John-Peters-UW committed Nov 5, 2024
1 parent ba7e603 commit a2ac01b
Showing 10 changed files with 21,711 additions and 60 deletions.
105 changes: 83 additions & 22 deletions metl/model_encoder.py
@@ -1,16 +1,32 @@
import torch
from typing import Literal
from biopandas.pdb import PandasPdb
from metl.encode import DataEncoder
import metl.relative_attention as ra
from Bio.SeqUtils import seq1
import os
import warnings

class ModelEncoder(torch.nn.Module):
def __init__(self, model, encoder, strict=True, indexing:Literal[0,1] = 0) -> None:
def __init__(self, model: torch.nn.Sequential, encoder: DataEncoder,
strict:bool=True, indexing:Literal[0,1] = 0) -> None:
"""Wrapper to provide input sanitization and validation for METL model and encoders
Args:
model (torch.nn.Sequential): METL model loaded from Zenodo
encoder (DataEncoder): METL encoder for the above model loaded from Zenodo
strict (bool, optional): Strict mode requires the PDB sequence to match the wild-type sequence. Defaults to True.
indexing (Literal[0,1], optional): Which indexing the input variants use.
Defaults to 0, as METL is designed around 0-based indexing.
Raises:
Exception: Throws an exception if the indexing is invalid
"""

super(ModelEncoder, self).__init__()

if indexing != 0 and indexing != 1:
raise Exception("Indexing must be equal to 0 or to 1.")
raise AssertionError("Indexing must be 0 or 1.")

self.model = model
self.encoder = encoder
@@ -20,35 +36,69 @@ def __init__(self, model, encoder, strict=True, indexing:Literal[0,1] = 0) -> None:

self.needs_pdb = self.check_if_pdb_needed(model)

def check_if_pdb_needed(self, model):
def check_if_pdb_needed(self, model: torch.nn.Sequential) -> bool:

sequential = next(model.children())

for layer in sequential:
if isinstance(layer, ra.RelativeTransformerEncoder):
return True
return False

def validate_pdb(self, pdb_file, wt):
def validate_pdb(self, pdb_file: str, wt: str|list[str]):
"""
When validating a PDB, the PDB file and the wild type (wt) passed in may disagree.
Strict mode raises an exception if this occurs; otherwise the mismatch is not checked.
Strict is off by default when loading from a checkpoint file, and on when loading models from Zenodo.
Args:
pdb_file (str): The path to the PDB file
wt (str | list[str]): The wild type sequence, or a list of full sequences
Raises:
Exception: Raises exceptions for multi-chain input PDB files or files that PandasPdb cannot load
"""

# Check that the file can be parsed at all
try:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always") # Capture every warning

# Reading the PDB may emit format warnings
ppdb = PandasPdb().read_pdb(pdb_file)

if len(w) > 0:
warns = [str(warn.message) for warn in w]
joined_warnings = "\n".join(warns)
raise AssertionError(f"PandasPdb raised a warning, which usually means the PDB file is malformed: \n{joined_warnings}")
except Exception as e:
raise Exception(f"{str(e)} \n\n PDB file could not be read by PandasPDB. It may be incorrectly formatted.")
raise ValueError(f"{e} \n\nPDB file could not be read by PandasPdb. It may be incorrectly formatted.")

# Check multi-chain
num_chains = ppdb.df['ATOM'].groupby('chain_id').size().size

if num_chains != 1:
raise ValueError(f"PDB file has {num_chains} chains. METL only supports single chain PDB files.")

groups = ppdb.df['ATOM'].groupby('residue_number')
wt_seq = []
for group_name, group_data in groups:
wt_seq.append(seq1(group_data.iloc[0]['residue_name']))
wildtype = ''.join(wt_seq)

if self.strict:
if self.strict and isinstance(wt, str):
err_str = "Strict mode is on because a METL model that we trained was used. Wildtype and PDB sequences must match."
err_str += " To ignore the sequence mismatch, pass strict=False to the load function you used."
assert wildtype == wt, err_str
elif isinstance(wt, list):
assert all([isinstance(AASeq, str) for AASeq in wt]), "One or more sequences in the list you passed were not strings."
# WT seqs must be the same length
for seq in wt:
err_str = "One of the sequences in the list of sequences you passed was not the same length as the "
err_str += "first sequence. All sequences must be the same length."
assert len(seq) == len(wt[0]), err_str

def validate_variants(self, variants, wt):
"""
@@ -62,21 +112,21 @@ def validate_variants(self, variants, wt):
to_amino_acid = mutation[-1]
location = int(mutation[1:-1])

errors = []
error = None

if location <= 0 or location >= wt_len-1:
if location < 0 or location >= wt_len:
error_str = f"The position for the mutation is {location}, but it must be between 0 "
error_str += f"and {len(wt)-1} (0-based) or 1 and {len(wt)} (1-based)."
errors.append(error_str)
error = error_str
elif wt[location] != from_amino_acid:
errors.append(f"Wildtype at position {location} is {wt[location]} but variant had {from_amino_acid}. Check the variant input.")
error = f"Wildtype at position {location} is {wt[location]} but variant had {from_amino_acid}. Check the variant input."

if len(errors) != 0:
if error is not None:
if self.indexing == 1:
mutation = f"{from_amino_acid}{location+1}{to_amino_acid}"
one_based_variants = self.change_indexing_to(1, variants)

raise Exception(f"Invalid mutation {mutation} that is inside variant {one_based_variants[index]}. Errors: {', '.join(errors)}")
raise AssertionError(f"Invalid mutation {mutation} that is inside variant {one_based_variants[index]}. Error: {error}")

def change_indexing_to(self, indexing, variants):
changed_based_variants = []
@@ -98,24 +148,35 @@ def change_indexing_to(self, indexing, variants):

return changed_based_variants

def forward(self, wt:str, variants:list[str], pdb_fn:str=None):
if self.needs_pdb and pdb_fn is None:
raise Exception("PDB path is required but it was not given. Do you have a PDB file?")
def forward(self, wt: str|list[str], variants: list[str]|None = None, pdb_fn: str|None = None):
if isinstance(wt, str) and self.needs_pdb and pdb_fn is None:
raise AssertionError("PDB path is required but it was not given. Do you have a PDB file?")

if pdb_fn:
pdb_fn = os.path.abspath(pdb_fn)
pdb_fn = os.path.abspath(os.path.expanduser(pdb_fn))
self.validate_pdb(pdb_fn, wt)

encoded_variants = None
if isinstance(wt, list):
# Reusing this variable so we can have a simpler program flow
assert all([isinstance(AASeq, str) for AASeq in wt]), "All sequences in wt must be type str"
encoded_variants = self.encoder.encode_sequences(wt)

if self.indexing == 1:
variants = self.change_indexing_to(0, variants)
if variants is not None:
if isinstance(variants, str):
raise ValueError("variants is a single string, not a list. Did you forget to pass pdb_fn as a keyword argument when passing a list of sequences?")
if self.indexing == 1:
variants = self.change_indexing_to(0, variants)

self.validate_variants(variants, wt)

encoded_variants = self.encoder.encode_variants(wt, variants)

assert encoded_variants is not None, "Encoding did not happen. If wt is a string, you must pass variants. If wt is a list, you cannot pass variants."

if pdb_fn:
pred = self.model(torch.Tensor(encoded_variants), pdb_fn=pdb_fn)
pred = self.model(torch.tensor(encoded_variants), pdb_fn=pdb_fn)
else:
pred = self.model(torch.Tensor(encoded_variants))
pred = self.model(torch.tensor(encoded_variants))

return pred
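For orientation, here is a minimal sketch of how the two input modes of the updated forward method are meant to be exercised. The metl.get_from_ident loader and the model identifier below are assumptions drawn from the package's examples, not part of this diff:

import torch
import metl

# Assumed loader: recent versions return the ModelEncoder wrapper above
model = metl.get_from_ident("metl-l-2m-3d-gb1")

wt = "MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE"

with torch.no_grad():
    # Mode 1: wild-type string plus a list of variant strings
    # (0-based indexing by default); structure-based models need a PDB file
    pred = model(wt, ["T17P,T54F", "V28L,F51A"], pdb_fn="2qmt_p.pdb")

    # Mode 2 (new in this commit): a list of full sequences, no variants;
    # all sequences must be strings of equal length
    pred_seqs = model([wt, wt], pdb_fn="2qmt_p.pdb")

Passing strict=False to the load function skips the wild-type/PDB sequence match, as the docstring above notes.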
15 changes: 15 additions & 0 deletions notebooks/bleh.py
@@ -0,0 +1,15 @@
from transformers import AutoModel
import torch

metl = AutoModel.from_pretrained('gitter-lab/METL', trust_remote_code=True)

model = "metl-l-2m-3d-gb1"
wt = "MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE"
variants = '["T17P,T54F", "V28L,F51A"]'
pdb_path = './2qmt_p.pdb'

metl.load_from_ident(model)
metl.eval()

with torch.no_grad():
    predictions = metl(wt, variants, pdb_path)
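Note that this script passes variants as a JSON-formatted string rather than a Python list; the Hugging Face wrapper is presumably expected to parse it before delegating to ModelEncoder, whose forward now rejects a bare string for variants.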
117 changes: 84 additions & 33 deletions notebooks/test_validation.ipynb
@@ -2,16 +2,15 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import metl\n",
"import torch\n",
"from biopandas.pdb import PandasPdb\n",
"from Bio.SeqUtils import seq1\n",
"\n",
"import metl.relative_attention as ra"
"from biopandas.pdb import PandasPdb"
]
},
{
@@ -46,48 +45,100 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4\n"
]
},
{
"ename": "Exception",
"evalue": "PDB file has 4 chains. METL only supports single chain PDB files.",
"output_type": "error",
"traceback": [
"---------------------------------------------------------------------------",
"Exception                                 Traceback (most recent call last)",
"Cell In[4], line 2\n      1 with torch.no_grad():\n----> 2     pred = model_needs_pdb(wt, variants, '../pdbs/8hgs_multi_chain.pdb')",
"File c:\\Users\\johng\\Workspaces\\gitter_stuff\\metl-pretrained\\venv\\lib\\site-packages\\torch\\nn\\modules\\module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)\n-> 1553 return self._call_impl(*args, **kwargs)",
"File c:\\Users\\johng\\Workspaces\\gitter_stuff\\metl-pretrained\\venv\\lib\\site-packages\\torch\\nn\\modules\\module.py:1562, in Module._call_impl(self, *args, **kwargs)\n-> 1562 return forward_call(*args, **kwargs)",
"File c:\\Users\\johng\\Workspaces\\gitter_stuff\\metl-pretrained\\venv\\lib\\site-packages\\metl\\model_encoder.py:144, in ModelEncoder.forward(self, wt, variants, pdb_fn)\n--> 144 self.validate_pdb(pdb_fn, wt)",
"File c:\\Users\\johng\\Workspaces\\gitter_stuff\\metl-pretrained\\venv\\lib\\site-packages\\metl\\model_encoder.py:71, in ModelEncoder.validate_pdb(self, pdb_file, wt)\n---> 71 raise Exception(f\"PDB file has {num_chains} chains. METL only supports single chain PDB files.\")",
"Exception: PDB file has 4 chains. METL only supports single chain PDB files."
]
}
],
"source": [
"with torch.no_grad():\n",
" pred = model_needs_pdb(wt, variants, '../pdbs/8hgs_multi_chain.pdb')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[ 4.7697e-01, 2.8120e-01, -9.2805e-02, 1.3844e+00, -1.6826e-01,\n",
" 3.0760e-01, 2.9817e-01, -1.3989e+00, -6.9998e-03, 2.9667e+00,\n",
" 5.2389e-01, -1.1593e-01, -6.0801e-01, 1.3960e+00, -7.3188e-01,\n",
" 4.5726e-01, 2.4292e-01, -6.2117e-01, -3.2025e-01, -2.7432e-03,\n",
" 1.0857e+00, 7.2377e-01, -1.8836e-01, 1.2242e+00, -2.0520e-01,\n",
" -1.2772e-01, -5.5548e-01, 1.7935e-01, -4.5334e-02, 5.6413e-01,\n",
" 7.0476e-01, 2.2904e-01, 1.3629e-02, -3.3900e-01, 9.5354e-01,\n",
" 2.7879e-01, 1.5288e+00, 6.2959e-02, -9.8674e-01, 9.6503e-01,\n",
" 1.1230e+00, 7.0476e-01, -1.5055e+00, 3.4333e-01, 1.8991e+00,\n",
" 7.1499e-01, -3.8434e-01, -5.2042e-01, 2.5414e+00, 1.2656e+00,\n",
" -8.8348e-01, 2.6817e+00, -7.5841e-02, 2.7973e+00, -1.7446e-01],\n",
" [-7.1824e-01, -1.0838e-01, -3.0332e-01, 4.3212e-01, -2.2647e-01,\n",
" -6.0814e-01, -3.1514e-01, -9.7168e-01, 2.2807e-01, -4.4759e-01,\n",
" 5.1764e-01, -6.7759e-01, 5.7723e-01, 5.8746e-02, -9.1098e-01,\n",
" -1.7514e-01, -5.6852e-01, 3.0986e-01, -6.1122e-01, -2.3290e-01,\n",
" 3.0098e-01, 6.0075e-01, 6.4402e-01, 3.9483e-01, 3.4120e-01,\n",
" 1.3791e-01, 7.0088e-01, 2.0338e-01, 1.0337e+00, -2.1346e-01,\n",
" -2.6859e-02, 2.6972e-01, 4.1215e-01, 2.8373e-01, -5.9371e-01,\n",
" 4.7806e-01, 1.2857e-01, 3.1594e-02, -2.4400e-01, 1.0700e-01,\n",
" 5.0299e-02, -2.6863e-02, 6.8558e-01, -3.8530e-01, -2.0537e-01,\n",
" -1.3260e-01, -1.0593e+00, -1.3271e-03, -4.2777e-02, 2.2335e-01,\n",
" -3.1454e-01, -7.8193e-02, -2.5404e-02, 3.4191e-01, -1.7368e-01]]),\n",
" tensor([[-3.6763],\n",
" [-3.2601]]))"
"'E'"
]
},
"execution_count": 4,
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"with torch.no_grad():\n",
" pred = model_needs_pdb(wt, variants, pdb_fn)\n",
" pred2 = model_no_pdb(wt, variants)\n",
"pred, pred2"
"\"MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE\"[55]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ppdb = PandasPdb().read_pdb('../pdbs/8hgs_multi_chain.pdb')\n",
"scppdb = PandasPdb().read_pdb('../pdbs/1gfl_cm.pdb')\n",
"for pdb_obj in [ppdb, scppdb]:\n",
" pdb_df:pd.DataFrame = pdb_obj.df['ATOM']\n",
" print(pdb_df.groupby('chain_id').size().size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# from transformers import AutoModel\n",
"# import torch\n",
"\n",
"# metl = AutoModel.from_pretrained('gitter-lab/METL', trust_remote_code=True)\n",
"\n",
"# model = \"metl-l-2m-3d-gb1\"\n",
"# wt = \"MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVDGEWTYDDATKTFTVTE\"\n",
"# variants = '[\"T17P,T54F\", \"V28L,F51A\"]'\n",
"# pdb_path = './2qmt_p.pdb'\n",
"\n",
"# metl.load_from_ident(model)\n",
"# metl.eval()\n",
"\n",
"# with torch.no_grad(): \n",
"# predictions = metl(wt, variants, pdb_path)"
]
}
],