-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdim_reducer.py
97 lines (63 loc) · 2.72 KB
/
dim_reducer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import logging
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
import torch
class Dim_Reducer(object):
    """Reduce model features (or raw batches) to a 2-D embedding, e.g. for plotting."""

    def __init__(self, device):
        # Target passed to ``.to()`` for the model and its inputs while
        # features are extracted (e.g. "cpu" or "cuda:0").
        self.device = device

    @staticmethod
    def _to_numpy(x):
        """Return *x* as a NumPy array; torch tensors are detached and copied to CPU first."""
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def unsupervised_reduce(self, reduce_method="tSNE",
                            model=None, batch_data=None, data_loader=None, num_points=1000):
        """Embed data into 2-D with an unsupervised reduction method.

        Args:
            reduce_method: Name of the reduction method; only ``"tSNE"`` is
                supported (``TSNE(n_components=2, random_state=33)``).
            model: Optional model. When given, features are taken from it via
                ``get_features`` (see that method for the model contract and
                its side effects). When ``None``, ``batch_data`` is reduced as-is.
            batch_data: ``(data, labels)`` pair.
            data_loader: Iterable of ``(data, labels)`` batches; only used
                together with ``model``.
            num_points: Maximum number of points drawn from ``data_loader``.

        Returns:
            ``(embedding, labels)``: the output of ``fit_transform`` and the
            labels, both as NumPy arrays.

        Raises:
            NotImplementedError: unknown ``reduce_method``, or neither
                ``model`` nor ``batch_data`` was provided.
        """
        if reduce_method == "tSNE":
            reducer = TSNE(n_components=2, random_state=33)
            other_params = {}
        else:
            raise NotImplementedError

        if model is not None:
            # get_features already moves the model back to the CPU.
            data_input, labels = self.get_features(
                model=model, batch_data=batch_data,
                data_loader=data_loader, num_points=num_points)
        elif batch_data is not None:
            data_input, labels = batch_data
        else:
            raise NotImplementedError

        # Convert explicitly so device tensors / tensors requiring grad work too.
        data_tsne = reducer.fit_transform(self._to_numpy(data_input), **other_params)
        return data_tsne, self._to_numpy(labels)

    def get_features(self, model=None, batch_data=None, data_loader=None, num_points=1000):
        """Run *model* without gradients and collect the features it returns.

        The model is called as ``output, feat = model(data)``; only ``feat``
        is kept.

        Args:
            model: Callable returning an ``(output, features)`` pair and
                supporting ``eval()`` / ``to()`` (e.g. a ``torch.nn.Module``).
            batch_data: A single ``(data, labels)`` batch. Takes precedence
                over ``data_loader``.
            data_loader: Iterable of ``(data, labels)`` batches. Batches are
                consumed only until at least ``num_points`` samples are seen.
            num_points: Maximum number of samples returned from ``data_loader``.

        Returns:
            ``(features, labels)``. Features are on the CPU; labels are
            returned as provided (concatenated along dim 0 for a loader).

        Raises:
            NotImplementedError: ``model`` is None, or neither ``batch_data``
                nor ``data_loader`` was provided.

        Side effects: the model is left on the CPU. If the model exposes a
        ``training`` flag that was True, training mode is restored afterwards.
        """
        if model is None:
            raise NotImplementedError

        was_training = getattr(model, "training", False)
        model.eval()
        model = model.to(self.device)
        with torch.no_grad():
            if batch_data is not None:
                data, labels = batch_data
                _, feat = model(data.to(self.device))
                feat = feat.to("cpu")
            elif data_loader is not None:
                feat_list = []
                labels_list = []
                loaded_num_points = 0
                for data, batch_labels in data_loader:
                    _, batch_feat = model(data.to(self.device))
                    # Move each batch off the device immediately so device
                    # memory does not grow with the number of batches.
                    feat_list.append(batch_feat.to("cpu"))
                    labels_list.append(batch_labels)
                    loaded_num_points += data.shape[0]
                    # Stop as soon as enough points are collected; the
                    # previous `<` comparison fetched one extra batch.
                    if loaded_num_points >= num_points:
                        break
                feat = torch.cat(feat_list, dim=0)[:num_points]
                labels = torch.cat(labels_list, dim=0)[:num_points]
            else:
                raise NotImplementedError
        logging.info("feat.shape: %s", feat.shape)
        model.to("cpu")
        if was_training:
            # Don't leave a model that was mid-training stuck in eval mode.
            model.train()
        return feat, labels