add scripts for clustering #1

Draft
wants to merge 13 commits into base: main
99 changes: 99 additions & 0 deletions data/benchmark_pipeline.py
@@ -0,0 +1,99 @@
from argparse import ArgumentParser
import pandas as pd
from pathlib import Path

def filter_sample(clusters_csv, benchmark_clusters, clustering_regions, t):
    """Split the clusters: benchmark complexes whose cluster is much larger than
    its benchmark share stay in train; otherwise the whole cluster is removed
    from train to avoid leakage into the benchmark."""
    cl_bench = benchmark_clusters.loc[:, 'cluster'].value_counts()
    new_bench = benchmark_clusters['id'].tolist()
    bench_deleted = []
    cluster_deleted = pd.DataFrame()
    for _, row in benchmark_clusters.iterrows():
        if row['size'] / cl_bench[row['cluster']] > t:
            # oversized cluster: keep it in train, drop the complex from the benchmark
            new_bench.remove(row['id'])
            bench_deleted.append(row['id'])
        else:
            # move the whole cluster out of train
            cluster_deleted = pd.concat([cluster_deleted, clusters_csv[clusters_csv[clustering_regions].isin([row['cluster']])]])
            clusters_csv = clusters_csv[~clusters_csv[clustering_regions].isin([row['cluster']])]
    return clusters_csv, pd.Series(new_bench), cluster_deleted, pd.Series(bench_deleted)


def benchmark_statistics(benchmark_path, clusters_path, clustering_regions, save=False):
    """For each benchmark complex, record the cluster it fell into and the
    total size of that cluster (brute-force scan over both tables)."""
    benchmark_csv = pd.read_csv(benchmark_path, header=None)
    clusters_csv = pd.read_csv(clusters_path, sep='\t')
    benchmark_clusters_csv = pd.DataFrame()
    for _, v in benchmark_csv.iterrows():
        for _, row in clusters_csv.iterrows():
            # substring match: the benchmark lists ids that occur inside the full complex id
            if v.values[0] in row['id']:
                benchmark_clusters_csv.loc[row['id'], 'id_complex'] = v.values[0]
                benchmark_clusters_csv.loc[row['id'], 'id'] = row['id']
                benchmark_clusters_csv.loc[row['id'], 'cluster'] = row[clustering_regions]
                benchmark_clusters_csv.loc[row['id'], 'size'] = (row[clustering_regions] == clusters_csv[clustering_regions]).sum()

    benchmark_clusters_csv['size'] = benchmark_clusters_csv['size'].astype(int)
    if save:
        benchmark_clusters_csv.sort_values('cluster').to_csv(f'benchmark_clusters_{clustering_regions}.tsv', sep='\t', index=False)
    return benchmark_clusters_csv

def is_in_benchmark(clusters_path, benchmark_clusters_csv, clustering_regions, save=False):
    """Mark every cluster with whether it contains a benchmark complex."""
    clusters_csv = pd.read_csv(clusters_path, sep='\t')
    clusters = clusters_csv[clustering_regions].unique()
    benchmark_cluster_ids = set(benchmark_clusters_csv['cluster'])
    in_benchmark = pd.DataFrame()

    for cl in clusters:
        # NaN cluster labels are treated as not in the benchmark
        if pd.notna(cl) and cl in benchmark_cluster_ids:
            in_benchmark.loc[cl, 'in_benchmark'] = True
            in_benchmark.loc[cl, 'size'] = benchmark_clusters_csv.loc[benchmark_clusters_csv['cluster'] == cl, 'size'].iloc[0]
        else:
            in_benchmark.loc[cl, 'in_benchmark'] = False
            in_benchmark.loc[cl, 'size'] = 0
    in_benchmark['0'] = in_benchmark.index
    in_benchmark['size'] = in_benchmark['size'].astype(int)
    in_benchmark = in_benchmark[['0', 'in_benchmark', 'size']].dropna()

    if save:
        in_benchmark.sort_values('0').to_csv(f'in_benchmark_{clustering_regions}.tsv', sep='\t', index=False)
    return in_benchmark


if __name__ == '__main__':
    parser = ArgumentParser('Benchmark-aware train/test split of MMseqs clusters')
    parser.add_argument('--benchmark_path', type=Path, default=Path('rabd_benchmark.txt'))
    parser.add_argument('--clusters_path', type=Path, default=Path('MMseq/joined_clusters.tsv'))
    parser.add_argument('--clustering_regions', default='renamed_clusterRes_0.5_DB_CDR_H3.fasta_cluster')
    parser.add_argument('--threshold', type=int, default=30)
    parser.add_argument('--sample_dir', default=Path('train_val_test'), type=Path)
    args = parser.parse_args()

    benchmark_path = args.benchmark_path
    clusters_path = args.clusters_path
    clustering_regions = args.clustering_regions
    t = args.threshold
    sample_dir = args.sample_dir

    benchmark_clusters_csv = benchmark_statistics(benchmark_path, clusters_path, clustering_regions, save=True)
    clusters_csv = pd.read_csv(clusters_path, sep='\t')


    in_benchmark = is_in_benchmark(clusters_path, benchmark_clusters_csv, clustering_regions, save=True)

    train, test, del_train, del_test = filter_sample(clusters_csv, benchmark_clusters_csv, clustering_regions, t)

    sample_dir.mkdir(exist_ok=True, parents=True)
    train[['id', clustering_regions]].to_csv(sample_dir / f'train_and_val_{clustering_regions}.tsv', sep='\t', index=False)
    test.to_csv(sample_dir / f'test_{clustering_regions}.tsv', sep='\t', index=False, header=False)
    del_train['id'].to_csv(sample_dir / f'deleted_train_and_val_{clustering_regions}.tsv', sep='\t', index=False, header=False)
    del_test.to_csv(sample_dir / f'deleted_test_{clustering_regions}.tsv', sep='\t', index=False, header=False)

    # sanity check: every complex lands in exactly one of the four outputs
    s = set(train['id']) | set(test) | set(del_train['id']) | set(del_test)
    assert len(s) == pd.read_csv(clusters_path, sep='\t').shape[0]
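For reference, a minimal sketch of the two inputs this pipeline expects. The file names follow the argparse defaults above; the ids and cluster labels are hypothetical:

# rabd_benchmark.txt: one benchmark complex id (or id fragment) per line, no header.
# MMseq/joined_clusters.tsv: tab-separated, an 'id' column plus one
# '<fasta>_cluster' column per clustering run, e.g.:
import pandas as pd
clusters = pd.DataFrame({
    'id': ['1abc_H+L-A', '2xyz_H+L-B', '3def_H+L-C'],
    'renamed_clusterRes_0.5_DB_CDR_H3.fasta_cluster': ['1abc_H+L-A', '1abc_H+L-A', '3def_H+L-C'],
})
clusters.to_csv('MMseq/joined_clusters.tsv', sep='\t', index=False)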
24 changes: 24 additions & 0 deletions data/create_clusters_table.py
@@ -0,0 +1,24 @@
from pathlib import Path
import pandas as pd
from argparse import ArgumentParser

if __name__ == '__main__':
    parser = ArgumentParser('Join per-region cluster TSVs on the complex id')
    parser.add_argument('--tsvs', type=Path, nargs='+')
    parser.add_argument('--joined_tsv', type=Path, default=Path('MMseq/joined_clusters.tsv'))
    args = parser.parse_args()
    tsvs = args.tsvs
    joined_tsv = args.joined_tsv

    # one cluster column per input file, named after the file stem
    df_list = [pd.read_csv(t, sep='\t', header=None, names=[t.stem, 'id']) for t in tsvs]
    df_joined = df_list[0].copy()
    for df in df_list[1:]:
        df_joined = df_joined.merge(df, on='id', how='outer', suffixes=('_x', '_y'))
    # put the shared 'id' column first
    cols = ['id'] + [col for col in df_joined.columns if col != 'id']
    df_joined = df_joined[cols]

    if joined_tsv.exists():
        raise FileExistsError(f'{joined_tsv} already exists; pass a different --joined_tsv or delete it')
    df_joined.to_csv(joined_tsv, sep='\t', index=False)
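A toy illustration of the outer merge this script performs; the file stems and ids are hypothetical:

import pandas as pd
a = pd.DataFrame({'run_cdr_h3': ['c1', 'c1'], 'id': ['x', 'y']})    # from run_cdr_h3.tsv
b = pd.DataFrame({'run_all_cdrs': ['k1', 'k2'], 'id': ['x', 'z']})  # from run_all_cdrs.tsv
a.merge(b, on='id', how='outer')  # one row per id ('x', 'y', 'z'), one column per run, NaN where missing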
97 changes: 97 additions & 0 deletions data/create_fastas_for_mmseqs2.py
@@ -0,0 +1,97 @@
from enum import Enum
from tqdm import tqdm
import logging
import traceback
from pathlib import Path
from typing import Iterable, cast
from argparse import ArgumentParser
import pandas as pd
import subprocess

from proteinlib.structure.antibody_antigen_complex import AntibodyAntigenComplex, NumberingScheme

# indices into the concatenated [*vl_regions, *vh_regions] list built below
class Regions(Enum):
FR_L1=0
CDR_L1=1
FR_L2=2
CDR_L2=3
FR_L3=4
CDR_L3=5
FR_L4=6

FR_H1=7
CDR_H1=8
FR_H2=9
CDR_H2=10
FR_H3=11
CDR_H3=12
FR_H4=13

if __name__=='__main__':
logging.basicConfig(filename='creating_fastas.log', level=logging.DEBUG)

parser=ArgumentParser('Create fastas for clustering via MMseqs')
    parser.add_argument('--summary_csv', type=Path, default=Path('/mnt/sabdab/summary.csv'))
    parser.add_argument('--chothia_subdir', type=Path, default=Path('/mnt/sabdab/chothia'))
parser.add_argument('--regions',default=['CDR_L1','CDR_L2','CDR_L3','CDR_H1','CDR_H2','CDR_H3'],nargs='*')
parser.add_argument('--identity',default=0.5,type=float)
parser.add_argument('--result_folder',default='MMseq',type=Path)
args=parser.parse_args()

summary_csv=args.summary_csv
chothia_subdir=args.chothia_subdir
regions_list=args.regions

# only protein and peptide antigens
df = (
pd.read_csv(
summary_csv,
sep="\t",
usecols=["pdb", "Hchain", "Lchain", "antigen_chain", "antigen_type"],
)
.query("antigen_type in ('protein', 'peptide', 'protein | peptide', 'peptide | protein')")
.dropna()
.reset_index()
)
print(f"Summary records: {df.shape[0]}")
d={}
for i, row in tqdm(cast(Iterable[tuple[int, pd.Series]], df.iterrows()),total=df.shape[0]):
uid = f'{row["pdb"]}_{row["Hchain"]}+{row["Lchain"]}-{row["antigen_chain"]}'
try:
            sequences: list[str] = []
antigen_chains = tuple(map(lambda s: s.strip(), str(row["antigen_chain"]).split(" | ")))
ab_complex = AntibodyAntigenComplex.from_pdb(
pdb=chothia_subdir / f"{row['pdb']}.pdb",
heavy_chain_id=str(row["Hchain"]),
light_chain_id=str(row["Lchain"]),
antigen_chain_ids=antigen_chains,
numbering=NumberingScheme.CHOTHIA,
)
# get regions
            vh_regions = [region.sequence for region in ab_complex.antibody.heavy_chain.regions]
            vl_regions = [region.sequence for region in ab_complex.antibody.light_chain.regions]
all_regions = [*vl_regions, *vh_regions]

            # concatenate the requested regions into one sequence for clustering
            for region in regions_list:
                sequences.append(all_regions[Regions[region].value])
            d[uid] = ''.join(sequences)
        except FileNotFoundError:
            # PDB file missing from the chothia dump; skip it
            continue
        except Exception as err:
            msg = ''.join(traceback.format_exception(err))
            logging.warning(f'In complex {uid}: {msg}')
            print(f'In complex {uid}: {msg}')
            continue
    # one FASTA record per complex
    fasta_str = '\n'.join(f'>{k}\n{v}' for k, v in d.items())

    out_dir = Path(args.result_folder) / '__'.join(args.regions)
    out_dir.mkdir(exist_ok=True, parents=True)
    fasta_file = out_dir / f'DB_{"_".join(regions_list)}.fasta'
    fasta_file.write_text(fasta_str)

    subprocess.run(
        ['mmseqs', 'easy-cluster', str(fasta_file), str(out_dir / f'clusterRes_{args.identity}_{fasta_file.name}'),
         'tmp', '--min-seq-id', str(args.identity), '-c', '0.8', '--cov-mode', '1'],
        check=True,  # surface clustering failures instead of ignoring them
    )
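A quick sanity sketch of the Regions indexing used above, with hypothetical placeholder strings: light-chain regions occupy indices 0-6 of all_regions and heavy-chain regions 7-13, so each enum value selects the intended slot:

vl_regions = ['frl1', 'cdrl1', 'frl2', 'cdrl2', 'frl3', 'cdrl3', 'frl4']
vh_regions = ['frh1', 'cdrh1', 'frh2', 'cdrh2', 'frh3', 'cdrh3', 'frh4']
all_regions = [*vl_regions, *vh_regions]
assert all_regions[Regions.CDR_L2.value] == 'cdrl2'  # index 3
assert all_regions[Regions.CDR_H3.value] == 'cdrh3'  # index 12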
28 changes: 28 additions & 0 deletions data/mark_clusters.py
@@ -0,0 +1,28 @@
import pandas as pd
from pathlib import Path

if __name__ == '__main__':
    benchmark_path = Path('benchmark_clusters.tsv')
    clusters_path = Path('/mnt/sabdab/clusters.tsv')

    # hard-coded clustering run this script operates on
    cluster_col = 'renamed_clusterRes_0.5_DB_CDR_H3_CDR_H2_CDR_H1_CDR_L1_CDR_L2_CDR_L3.fasta_cluster'

    clusters_csv = pd.read_csv(clusters_path, sep='\t')
    benchmark_csv = pd.read_csv(benchmark_path, sep='\t')

    clusters = clusters_csv[cluster_col].unique()
    benchmark_cluster_ids = set(benchmark_csv['cluster'])
    in_benchmark = pd.DataFrame()

    for cl in clusters:
        # NaN cluster labels are treated as not in the benchmark
        if pd.notna(cl) and cl in benchmark_cluster_ids:
            in_benchmark.loc[cl, 'in_benchmark'] = True
            in_benchmark.loc[cl, 'size'] = benchmark_csv.loc[benchmark_csv['cluster'] == cl, 'size'].iloc[0]
        else:
            in_benchmark.loc[cl, 'in_benchmark'] = False
            in_benchmark.loc[cl, 'size'] = 0

    in_benchmark['0'] = in_benchmark.index
    in_benchmark['size'] = in_benchmark['size'].astype(int)
    in_benchmark = in_benchmark[['0', 'in_benchmark', 'size']]
    in_benchmark.sort_values('0').to_csv('in_benchmark.tsv', sep='\t', index=False)
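The output mirrors the in_benchmark_<regions>.tsv files written by benchmark_pipeline.py; a hypothetical example of its shape:

# in_benchmark.tsv
# 0             in_benchmark    size
# 1abc_H+L-A    True            12
# 3def_H+L-C    False           0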
57 changes: 57 additions & 0 deletions data/plot_by_threshhold.py
@@ -0,0 +1,57 @@
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt

def filter_sample(clusters_csv, benchmark_clusters, clustering_regions, t):
    """Same split rule as in benchmark_pipeline.py, returning only the train
    clusters and the surviving benchmark ids."""
    cl_bench = benchmark_clusters.loc[:, 'cluster'].value_counts()
    new_bench = benchmark_clusters['id'].tolist()
    for _, row in benchmark_clusters.iterrows():
        if row['size'] / cl_bench[row['cluster']] > t:
            new_bench.remove(row['id'])
        else:
            clusters_csv = clusters_csv[~clusters_csv[clustering_regions].isin([row['cluster']])]
    return clusters_csv, pd.Series(new_bench)


if __name__ == '__main__':
    clusters_path = Path('/mnt/sabdab/clusters_with_1w72.tsv')
    clusters_csv = pd.read_csv(clusters_path, sep='\t')

    # alternative run over all six CDRs:
    # clustering_regions = 'renamed_clusterRes_0.5_DB_CDR_H1_CDR_H2_CDR_H3_CDR_L1_CDR_L2_CDR_L3.fasta_cluster'
    clustering_regions = 'renamed_clusterRes_0.5_DB_CDR_H3.fasta_cluster'
    benchmark_path = f'/data/user/shapoval/ProteinMPNN/benchmark_clusters_{clustering_regions}.tsv'
    benchmark_clusters_csv = pd.read_csv(benchmark_path, sep='\t')

    # sweep the size-ratio threshold and record the train/test trade-off
    T = range(200)
    train_size, test_size = [], []
    for t in tqdm(T):
        train, test = filter_sample(clusters_csv, benchmark_clusters_csv, clustering_regions, t)
        train_size.append(train.shape[0])
        test_size.append(test.shape[0])

    plt.plot(train_size, test_size)
    plt.title(clustering_regions[19:])
    plt.grid()
    plt.xlabel('train size')
    plt.ylabel('test size')
    plt.show()
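If concrete thresholds need to be read off the curve, a small optional sketch that could go just before plt.show() (the chosen t values are arbitrary):

    # annotate a few thresholds along the trade-off curve
    for t in (0, 50, 100, 150):
        plt.annotate(f't={t}', (train_size[t], test_size[t]))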
18 changes: 18 additions & 0 deletions data/remove_benchmark_from_sample.py
@@ -0,0 +1,18 @@
import pandas as pd

if __name__ == '__main__':
    cluster_col = 'renamed_clusterRes_0.5_DB_CDR_H3_CDR_H2_CDR_H1_CDR_L1_CDR_L2_CDR_L3.fasta_cluster'

    in_bench = pd.read_csv('in_benchmark.tsv', sep='\t').dropna()
    clusters_csv = pd.read_csv('/mnt/sabdab/clusters.tsv', sep='\t').dropna()

    # clusters that contain at least one benchmark complex
    ids_bench = in_bench.loc[in_bench['in_benchmark'], '0'].to_list()

    # keep only rows whose cluster is not touched by the benchmark
    df_del = clusters_csv[~clusters_csv[cluster_col].isin(ids_bench)]
    df_del.to_csv('filtered_sample.tsv', sep='\t', index=False)
44 changes: 44 additions & 0 deletions data/replace_cluster_names.py
@@ -0,0 +1,44 @@
import pandas as pd
from pathlib import Path
from argparse import ArgumentParser

def df_to_dict(df):
    """Map each cluster representative (column 0) to the list of its member ids (column 1)."""
    d = {}
    for _, row in df.iterrows():
        k, v = row.values
        d.setdefault(k, []).append(v)
    return d

def get_sorted_names(d):
    """Pick a canonical name per cluster: its lexicographically smallest member."""
    return {k: min(v) for k, v in d.items()}

def rename_clusters(df, d_sorted):
    """Relabel each row's cluster with its canonical name in a single pass."""
    rows = [{0: d_sorted[k], 1: v} for k, v in df.itertuples(index=False)]
    return pd.DataFrame(rows)

if __name__ == '__main__':
    parser = ArgumentParser('Rename clusters to canonical member-based names')
    parser.add_argument('--cluster_tsv', type=Path)
    args = parser.parse_args()
    cluster_tsv = args.cluster_tsv
    renamed_tsv = cluster_tsv.parent / f'renamed_{cluster_tsv.name}'

    df = pd.read_csv(cluster_tsv, sep='\t', header=None)
    d = df_to_dict(df)
    d_sorted = get_sorted_names(d)
    df_clusters_renamed = rename_clusters(df, d_sorted)
    assert df_clusters_renamed.shape[0] == df.shape[0]

    if renamed_tsv.exists():
        raise FileExistsError(f'{renamed_tsv} already exists; delete it or choose a different --cluster_tsv')
    df_clusters_renamed.to_csv(renamed_tsv, sep='\t', header=False, index=False)
    print(renamed_tsv)
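A toy run of the renaming logic with hypothetical ids: every member keeps its row, and each cluster is relabelled by its lexicographically smallest member:

import pandas as pd
df = pd.DataFrame([['rep1', 'b'], ['rep1', 'a'], ['rep2', 'z']])
d = df_to_dict(df)              # {'rep1': ['b', 'a'], 'rep2': ['z']}
d_sorted = get_sorted_names(d)  # {'rep1': 'a', 'rep2': 'z'}
rename_clusters(df, d_sorted)   # rows: ('a', 'b'), ('a', 'a'), ('z', 'z')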