-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
103 lines (77 loc) · 2.98 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import gc
import time

import cv2
import numpy as np
import torch
import torch.optim as optim
from sklearn.metrics import fbeta_score
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
import wandb

from config import CFG
from utils import *
from datasets import *
from model import build_model
from train import train_fn
from validate import valid_fn
wandb.init(project="vesuvius-proj", config=CFG)
Logger = init_logger(log_file=CFG.log_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for fragment_id in [1, 2, 3]:
if CFG.metric_direction == "minimize":
best_score = np.inf
elif CFG.metric_direction == "maximize":
best_score = -1
best_loss = np.inf
Logger.info(f"--------------model number:- {fragment_id}----------------")
train_loader, valid_loader, valid_xyxys = get_data_loader(fragment_id)
valid_mask_gt = cv2.imread(
CFG.comp_dataset_path + f"train/{fragment_id}/inklabels.png", 0
)
valid_mask_gt = valid_mask_gt / 255
pad0 = CFG.tile_size - valid_mask_gt.shape[0] % CFG.tile_size
pad1 = CFG.tile_size - valid_mask_gt.shape[1] % CFG.tile_size
valid_mask_gt = np.pad(valid_mask_gt, [(0, pad0), (0, pad1)], constant_values=0)
model = build_model(CFG, None)
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=CFG.lr)
scheduler = get_scheduler(CFG, optimizer)
for epoch in range(CFG.epochs):
start_time = time.time()
# Train
avg_loss = train_fn(train_loader, model, optimizer, device)
# Evaluate
avg_val_loss, mask_pred = valid_fn(
valid_loader, model, device, valid_xyxys, valid_mask_gt
)
scheduler.step()
best_dice, best_th = calc_cv(valid_mask_gt, mask_pred)
elapsed = time.time() - start_time
Logger.info(
f"Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f} avg_val_loss: {avg_val_loss:.4f} time: {elapsed:.0f}s"
)
Logger.info(f"Epoch {epoch+1} - avgScore: {best_dice:.4f}")
wandb.log(
{
"Epoch": epoch + 1,
"avg_train_loss": avg_loss,
"avg_val_loss": avg_val_loss,
"time_elapsed": elapsed,
"SCORE": best_dice,
}
)
if CFG.metric_direction == "minimize":
update_best = best_dice < best_score
elif CFG.metric_direction == "maximize":
update_best = best_dice > best_score
gc_collect()
if update_best:
print("------UPDATE BESTTT------")
torch.save(
{"model": model.state_dict()},
f"{CFG.model_dir}/{CFG.backbone}_fold_{fragment_id}_best.pth",
)
best_loss = avg_val_loss
best_score = best_dice
Logger.info(f"Epoch {epoch+1} - Save Best Score: {best_score:.4f} Model")
Logger.info(f"Epoch {epoch+1} - Save Best Loss: {best_loss:.4f} Model")
wandb.finish()