-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalidate.py
37 lines (29 loc) · 1.21 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch
import numpy as np
from tqdm import tqdm
from utils import AverageMeter
from config import CFG
# Define your other necessary imports here
def valid_fn(valid_loader, model, criterion, device, valid_xyxys, valid_mask_gt):
    """Run one validation pass and stitch tile predictions into a full mask.

    The model is put in eval mode and run without gradient tracking over every
    batch of ``valid_loader``. Each batch loss is fed to an ``AverageMeter``
    weighted by the batch size. Sigmoid-activated predictions are added into a
    canvas shaped like ``valid_mask_gt`` at the boxes given by ``valid_xyxys``.
    Pixels covered by several tiles end up holding the mean of the overlapping
    predictions.

    Args:
        valid_loader: Iterable of ``(images, labels)`` batches. Samples must
            arrive in the same order as ``valid_xyxys``.
            NOTE(review): this requires the loader not to shuffle; confirm.
        model: Network called as ``model(images)``. Its output goes to
            ``criterion`` and through ``torch.sigmoid``.
        criterion: Loss called as ``criterion(y_preds, labels)``.
        device: Target device for ``images`` and ``labels``.
        valid_xyxys: Sequence of ``(x1, y1, x2, y2)`` boxes, one per sample,
            used as ``[y1:y2, x1:x2]`` slice bounds into the canvas.
        valid_mask_gt: Array whose ``.shape`` sets the canvas size. Its values
            are not read here.

    Returns:
        ``(avg_loss, mask_pred)``: ``losses.avg`` from the ``AverageMeter``,
        and the per-pixel mean prediction. Pixels that no tile covers are 0.
    """
    mask_pred = np.zeros(valid_mask_gt.shape)
    mask_count = np.zeros(valid_mask_gt.shape)
    model.eval()
    losses = AverageMeter()
    # Running index into valid_xyxys. It is derived from the real batch sizes
    # (not ``step * CFG.valid_batch_size``) so boxes stay aligned with samples
    # even when the loader's batch size differs from the config value.
    offset = 0
    with torch.no_grad():
        for images, labels in tqdm(valid_loader, total=len(valid_loader)):
            images = images.to(device)
            labels = labels.to(device)
            batch_size = labels.size(0)
            y_preds = model(images)
            loss = criterion(y_preds, labels)
            losses.update(loss.item(), batch_size)
            # make whole mask
            probs = torch.sigmoid(y_preds).cpu().numpy()
            boxes = valid_xyxys[offset:offset + batch_size]
            for i, (x1, y1, x2, y2) in enumerate(boxes):
                # squeeze(0) drops a leading singleton axis, presumably the
                # channel dim of a (1, H, W) per-sample output. TODO confirm.
                mask_pred[y1:y2, x1:x2] += probs[i].squeeze(0)
                # Count coverage per pixel. A scalar 1 works for any box size,
                # unlike a fixed (tile_size, tile_size) array of ones.
                mask_count[y1:y2, x1:x2] += 1
            offset += batch_size
    print(f'mask_count_min: {mask_count.min()}')
    # Average the overlapping tiles. Pixels with zero coverage stay 0 instead
    # of turning into NaN from 0 / 0.
    mask_pred = np.divide(
        mask_pred,
        mask_count,
        out=np.zeros_like(mask_pred),
        where=mask_count > 0,
    )
    return losses.avg, mask_pred