# dataset.py
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from .utils.common import ESM_TOKENIZER
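
# ESM_TOKENIZER is assumed here to be a HuggingFace-style tokenizer exposing
# encode(seq, add_special_tokens=False); it is provided by .utils.common in this repository.
#
# Expected input format, inferred from __init__ below (a sketch, not a guaranteed schema):
# each line of data_config.path is a JSON record like
#   {"name": "...", "seq": "MKT...", "coords": [[[x, y, z], ...], ...]}
# with one coords entry per residue, each listing backbone atoms as [x, y, z]
# triples; the main-chain oxygen atom is dropped during loading.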


class StructureDataset(Dataset):
    """Dataset of structures read from a JSONL file (one JSON record per line)."""

    def __init__(self, data_config):
        super(StructureDataset, self).__init__()
        self.samples = []
        with open(data_config.path, 'r') as data_f:
            lines = data_f.readlines()
        for line in tqdm(lines):
            sample = json.loads(line)
            # Keep only samples whose coordinate and sequence lengths agree.
            if len(sample['coords']) == len(sample['seq']):
                # Keep the first three backbone atoms per residue, excluding the main-chain oxygen.
                sample['coords'] = [residue[:3] for residue in sample['coords']]
                self.samples.append(sample)

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)
    @staticmethod
    def featurize(batch, max_seq_length=512):
        # SciBERT supports a maximum sequence length of 512; LLaMA supports up to 4096 tokens.
        B = len(batch)
        L_max = max([len(b['seq']) for b in batch])
        X = np.zeros([B, L_max, 3, 3])            # backbone coordinates, [B, L_max, 3 atoms, 3 xyz]
        S = np.zeros([B, L_max], dtype=np.int32)  # token indices
        for i, sample in enumerate(batch):
            l = len(sample['seq'])
            x = sample['coords']
            # Pad each sample to L_max residues with NaN coordinates: [L_max, 3, 3].
            x_pad = np.pad(x, [[0, L_max - l], [0, 0], [0, 0]], 'constant', constant_values=(np.nan,))
            X[i, :, :, :] = x_pad
            indices = np.array(ESM_TOKENIZER.encode(sample['seq'], add_special_tokens=False))
            S[i, :l] = indices

        # Drop residues with missing (non-finite) coordinates and left-align the remainder.
        mask = np.isfinite(np.sum(X, (2, 3))).astype(np.int32)
        numbers = np.sum(mask, axis=1).astype(np.int32)
        S_new = np.zeros_like(S)
        X_new = np.zeros_like(X) + np.nan
        for i, n in enumerate(numbers):
            X_new[i, :n] = X[i][mask[i] == 1]
            S_new[i, :n] = S[i][mask[i] == 1]
        X = X_new
        S = S_new

        # Recompute the validity mask and zero out the remaining NaN padding.
        isnan = np.isnan(X)
        mask = np.isfinite(np.sum(X, (2, 3))).astype(np.int32)
        X[isnan] = 0.0

        # Truncate to the maximum sequence length supported by the language model.
        L = S.shape[1]
        if L > max_seq_length:
            X = X[:, :max_seq_length, ...]
            S = S[:, :max_seq_length]
            mask = mask[:, :max_seq_length]

        return {
            "name": [sample['name'] for sample in batch],
            "X": torch.from_numpy(X).to(torch.float32),
            "S": torch.from_numpy(S).to(torch.long),
            "mask": torch.from_numpy(mask).to(torch.bool),
        }
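
# A minimal sketch of what featurize returns for a synthetic batch (values are
# illustrative, and ESM_TOKENIZER is assumed to emit one token per residue):
#   batch = [{"name": "toy", "seq": "MK", "coords": [[[0.0, 0.0, 0.0]] * 3] * 2}]
#   out = StructureDataset.featurize(batch)
#   # out["X"]: float32 [1, 2, 3, 3], out["S"]: int64 [1, 2], out["mask"]: bool [1, 2]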


if __name__ == '__main__':
    import yaml
    from easydict import EasyDict

    with open('./entex/configs/entex_test.yml', 'r') as f:
        config = EasyDict(yaml.safe_load(f))
    dataset_config = config.dataset
    dataset = StructureDataset(dataset_config)
    StructureDataset.featurize([dataset[0], dataset[1]])
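
    # A minimal sketch (not part of the original script) showing featurize used as a
    # DataLoader collate_fn; batch_size=4 and shuffle=True are illustrative choices.
    from torch.utils.data import DataLoader
    loader = DataLoader(dataset, batch_size=4, shuffle=True,
                        collate_fn=StructureDataset.featurize)
    for feats in loader:
        print(feats['X'].shape, feats['S'].shape, feats['mask'].shape)
        break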