model.py
import math

import torch
from torch import nn


class MLP(nn.Module):
    """Simple feed-forward network: Linear layers with LeakyReLU activations."""

    def __init__(self, in_size: int, out_size: int, hidden_size: int, num_layers: int):
        super().__init__()
        self.activation = nn.LeakyReLU()
        # Input layer, (num_layers - 2) hidden layers, then the output layer.
        layers = [nn.Linear(in_size, hidden_size), self.activation]
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_size, hidden_size))
            layers.append(self.activation)
        layers.append(nn.Linear(hidden_size, out_size))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class WindowedMLP(nn.Module):
    """Reconstructs an input vector window by window.

    One set of MLPs covers non-overlapping windows of `window_size`; a second
    set covers windows shifted by half a window. The two reconstructions are
    summed in `forward`, so every position is estimated by two models.
    """

    def __init__(self, total_size: int, window_size: int, num_layers: int,
                 num_classes: int, num_super_classes: int):
        super().__init__()
        self.total_size = total_size
        self.window_size = window_size
        self.num_classes = num_classes
        self.num_super_classes = num_super_classes

        # Non-overlapping windows: each MLP sees its window plus the one-hot
        # class and super-class labels, and reconstructs the window.
        self.mlps = nn.ModuleList([])
        num_models = math.ceil(self.total_size / self.window_size)
        for i in range(num_models):
            if i == num_models - 1:
                # Size of the remainder; unlike `total_size % window_size`,
                # this stays `window_size` when total_size divides evenly.
                window_size = self.total_size - i * self.window_size
            else:
                window_size = self.window_size
            mlp = MLP(window_size + self.num_classes + self.num_super_classes,
                      window_size, window_size // 2, num_layers)
            self.mlps.append(mlp)

        # Overlapping windows, shifted by half a window: a half-size window
        # first, full windows in the middle, and whatever remains at the end.
        self.overlapping_mlps = nn.ModuleList([])
        num_models = math.ceil((self.total_size - self.window_size // 2) / self.window_size) + 1
        for i in range(num_models):
            if i == 0:
                window_size = self.window_size // 2
            elif i == num_models - 1:
                window_size = self.total_size - self.window_size // 2 - (i - 1) * self.window_size
            else:
                window_size = self.window_size
            mlp = MLP(window_size + self.num_classes + self.num_super_classes,
                      window_size, window_size // 2, num_layers)
            self.overlapping_mlps.append(mlp)
    def forward(self, genotypes, labels, super_labels):
        # One-hot label encodings, cast to float so they can be concatenated
        # with the (float) genotype windows.
        one_hot_label = nn.functional.one_hot(labels, self.num_classes).float()
        one_hot_super_label = nn.functional.one_hot(super_labels, self.num_super_classes).float()

        # Pass 1: non-overlapping windows.
        reconstructed_genotypes = []
        for genotype_i, mlp in zip(genotypes.split(self.window_size, 1), self.mlps):
            reconstructed_genotype = mlp(torch.cat([genotype_i, one_hot_label, one_hot_super_label], 1))
            reconstructed_genotypes.append(reconstructed_genotype)
        reconstructed_genotypes = torch.cat(reconstructed_genotypes, 1)

        # Pass 2: windows shifted by half a window (a half-window first, then
        # full windows over the rest).
        reconstructed_overlapping_genotypes = []
        shifted_windows = (genotypes[:, :self.window_size // 2],) + genotypes[:, self.window_size // 2:].split(self.window_size, 1)
        for genotype_i, mlp in zip(shifted_windows, self.overlapping_mlps):
            reconstructed_genotype = mlp(torch.cat([genotype_i, one_hot_label, one_hot_super_label], 1))
            reconstructed_overlapping_genotypes.append(reconstructed_genotype)
        reconstructed_overlapping_genotypes = torch.cat(reconstructed_overlapping_genotypes, 1)

        # Sum the two passes so each position receives two estimates.
        return reconstructed_genotypes + reconstructed_overlapping_genotypes
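

# A minimal usage sketch (not part of the original file): the sizes and
# hyperparameters below are illustrative assumptions, chosen so the windows
# divide unevenly and exercise the partial last window in both passes.
if __name__ == "__main__":
    total_size, window_size = 100, 32
    num_classes, num_super_classes = 5, 3
    model = WindowedMLP(total_size, window_size, num_layers=4,
                        num_classes=num_classes, num_super_classes=num_super_classes)
    genotypes = torch.randn(8, total_size)                    # batch of 8 genotype vectors
    labels = torch.randint(0, num_classes, (8,))              # per-sample class ids
    super_labels = torch.randint(0, num_super_classes, (8,))  # per-sample super-class ids
    reconstructed = model(genotypes, labels, super_labels)
    print(reconstructed.shape)  # torch.Size([8, 100])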