ssl_eval.py
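"""Linear evaluation (linear probing) of a self-supervised backbone.

Attaches a single trainable linear head to a frozen backbone and trains
it with the project's Trainer class.
"""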
import os

import torch
from torchinfo import summary

from trainer import Trainer


def linear_eval_backbone(
    curr_epoch,
    epochs,
    backbone,
    in_features,
    out_features,
    dataloader,
    batch_size,
    local_rank,
    paths,
    args,
    general_name,
    input_size: int = 224,
    verbose: bool = False,
    dropout: bool = False
):
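    """Train a linear probe on top of a frozen backbone.

    A flatten layer, an optional dropout layer, and a linear head are
    appended to the backbone; only the linear head's parameters are
    trained, using cross-entropy loss and SGD.
    """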
    # Build the evaluation model: the backbone followed by a flatten layer.
    model = torch.nn.Sequential(
        backbone,
        torch.nn.Flatten(),
    )
    # Optionally insert a dropout layer before the linear head.
    if dropout:
        model.add_module('dropout', torch.nn.Dropout(p=0.2, inplace=True))
        print(f'Dropout layer added: {model.dropout}')
    else:
        print('No dropout layer')
    # Add the final linear classification layer.
    model.add_module(
        'linear',
        torch.nn.Linear(in_features=in_features,
                        out_features=out_features,
                        bias=True)
    )
    # Report the newly attached fully-connected layer.
    print(f'New final fully-connected layer: {model[-1]}')
    # Freeze the whole network, then unfreeze only the final linear layer,
    # so that training updates the probe head alone.
    for param in model.parameters():
        param.requires_grad = False
    for param in model[-1].parameters():
        param.requires_grad = True
    print('Linear probing adjusted')
    # Show the model structure.
    if verbose:
        summary(
            model,
            input_size=(batch_size, 3, input_size, input_size),
            device=local_rank
        )
    # Configure the loss.
    loss_fn = torch.nn.CrossEntropyLoss()
    print(f'Loss: {loss_fn}')
    # Configure the optimizer. Frozen parameters never receive gradients,
    # so SGD effectively updates only the linear head even though all
    # parameters are passed in.
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)  # , momentum=0.9
    # Training: tag the run as multiclass linear probing ('LP').
    args.task_name = 'multiclass'
    args.transfer_learning = 'LP'
    args.load_best_hyperparameters = None
    trainer = Trainer(
        model,
        dataloader,
        batch_size,
        loss_fn,
        optimizer,
        save_every=10,
        snapshot_path=os.path.join(paths['snapshots'], f'head_{general_name}.pt'),
        csv_path=os.path.join(paths['csv_results'], f'head_{general_name}_e={curr_epoch}.csv'),
        distributed=False,
        lightly_train=True,
        ray_tune=False,
        ignore_ckpts=True
    )
    config = {
        'args': args,
        'epochs': epochs,
        'accuracy': 'test',
        'save_csv': True
    }
    trainer.train(config)
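

# --- Usage sketch (illustrative only) ---
# A minimal example of how this helper might be invoked. The backbone,
# dataloader, paths dict, and args namespace below are hypothetical
# placeholders, not objects defined in this file.
#
# import torchvision
# backbone = torch.nn.Sequential(
#     *list(torchvision.models.resnet18().children())[:-1]  # drop the fc head
# )
# linear_eval_backbone(
#     curr_epoch=0,
#     epochs=90,
#     backbone=backbone,
#     in_features=512,              # resnet18 feature dimension
#     out_features=10,              # number of target classes
#     dataloader=train_loader,      # hypothetical DataLoader
#     batch_size=256,
#     local_rank=0,
#     paths={'snapshots': './snapshots', 'csv_results': './results'},
#     args=args,                    # hypothetical argparse.Namespace
#     general_name='resnet18_probe',
#     verbose=True,
# )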