-
Notifications
You must be signed in to change notification settings - Fork 11
/
main.py
109 lines (92 loc) · 4.49 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import sys
import faiss
import torch
import numpy as np
from tqdm import tqdm
from pathlib import Path
from loguru import logger
from datetime import datetime
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Subset
import vpr_models
import parser
import visualizations
from test_dataset import TestDataset
def _extract_descriptors(model, dataset, indices, batch_size, num_workers, device, out):
    """Run ``model`` on the items of ``dataset`` selected by ``indices`` and store the results in ``out``.

    Each dataloader batch yields ``(images, items_indices)``; the descriptors of a batch are
    written to the rows ``out[items_indices]``, so ``out`` must have one row per dataset item.
    """
    dataloader = DataLoader(dataset=Subset(dataset, indices), num_workers=num_workers, batch_size=batch_size)
    for images, items_indices in tqdm(dataloader):
        descriptors = model(images.to(device))
        out[items_indices.numpy(), :] = descriptors.cpu().numpy()


def _compute_recalls(predictions, positives_per_query, recall_values):
    """Return Recall@N, in percent, for every N in ``recall_values``.

    A query is counted as correct at N if any of its first N predictions is one of its
    positives. The cumulative update ``recalls[i:] += 1`` assumes ``recall_values`` is sorted
    in ascending order — NOTE(review): confirm the argument parser guarantees this.
    """
    recalls = np.zeros(len(recall_values))
    for query_index, preds in enumerate(predictions):
        for i, n in enumerate(recall_values):
            if np.any(np.isin(preds[:n], positives_per_query[query_index])):
                # Correct at this N implies correct at every later (larger) N too.
                recalls[i:] += 1
                break
    # Divide by the number of queries and multiply by 100, so the recalls are in percentages
    return recalls / len(predictions) * 100


def main(args):
    """Evaluate a VPR model on a database/queries test set.

    Extracts descriptors for every database and query image, retrieves the nearest database
    images of each query with an exact L2 faiss index, logs Recall@N when ``args.use_labels``
    is set, and optionally saves the descriptors and prediction visualizations. All outputs go
    to ``logs/<args.log_dir>/<start timestamp>``.

    Args:
        args: parsed command-line arguments (see ``parser.parse_arguments``).
    """
    start_time = datetime.now()
    log_dir = Path("logs") / args.log_dir / start_time.strftime("%Y-%m-%d_%H-%M-%S")
    # np.save below does not create parent directories, so make sure the output folder exists
    # independently of the file sinks registered with the logger.
    log_dir.mkdir(parents=True, exist_ok=True)

    logger.remove()  # Remove possibly previously existing loggers
    logger.add(sys.stdout, colorize=True, format="<green>{time:%Y-%m-%d %H:%M:%S}</green> {message}", level="INFO")
    logger.add(log_dir / "info.log", format="<green>{time:%Y-%m-%d %H:%M:%S}</green> {message}", level="INFO")
    logger.add(log_dir / "debug.log", level="DEBUG")
    logger.info(" ".join(sys.argv))
    logger.info(f"Arguments: {args}")
    logger.info(
        f"Testing with {args.method} with a {args.backbone} backbone and descriptors dimension {args.descriptors_dimension}"
    )
    logger.info(f"The outputs are being saved in {log_dir}")

    model = vpr_models.get_model(args.method, args.backbone, args.descriptors_dimension)
    model = model.eval().to(args.device)

    test_ds = TestDataset(
        args.database_folder,
        args.queries_folder,
        positive_dist_threshold=args.positive_dist_threshold,
        image_size=args.image_size,
        use_labels=args.use_labels,
    )
    logger.info(f"Testing on {test_ds}")

    # Dataset layout (as used below): database items first, then query items.
    num_database = test_ds.num_database
    num_queries = test_ds.num_queries
    with torch.inference_mode():
        all_descriptors = np.empty((len(test_ds), args.descriptors_dimension), dtype="float32")

        logger.debug("Extracting database descriptors for evaluation/testing")
        _extract_descriptors(
            model, test_ds, list(range(num_database)), args.batch_size, args.num_workers, args.device, all_descriptors
        )

        # Queries always use batch size 1 — NOTE(review): presumably because query images may
        # differ in size and cannot be stacked into one batch; confirm before changing.
        logger.debug("Extracting queries descriptors for evaluation/testing using batch size 1")
        _extract_descriptors(
            model,
            test_ds,
            list(range(num_database, num_database + num_queries)),
            1,
            args.num_workers,
            args.device,
            all_descriptors,
        )

    queries_descriptors = all_descriptors[num_database:]
    database_descriptors = all_descriptors[:num_database]

    if args.save_descriptors:
        logger.info(f"Saving the descriptors in {log_dir}")
        np.save(log_dir / "queries_descriptors.npy", queries_descriptors)
        np.save(log_dir / "database_descriptors.npy", database_descriptors)

    # Use a kNN to find predictions
    faiss_index = faiss.IndexFlatL2(args.descriptors_dimension)
    faiss_index.add(database_descriptors)
    del database_descriptors, all_descriptors

    # Retrieve enough neighbours for both the largest Recall@N and the saved visualizations.
    # Never ask for more than the database holds: faiss pads missing results with -1, and
    # those labels would then be passed on (e.g. to save_preds) as if they were database indices.
    num_neighbors = max(max(args.recall_values), args.num_preds_to_save)
    num_neighbors = min(num_neighbors, max(num_database, 1))

    logger.debug("Calculating recalls")
    _, predictions = faiss_index.search(queries_descriptors, num_neighbors)

    # For each query, check if the predictions are correct
    if args.use_labels:
        positives_per_query = test_ds.get_positives()
        recalls = _compute_recalls(predictions, positives_per_query, args.recall_values)
        recalls_str = ", ".join([f"R@{val}: {rec:.1f}" for val, rec in zip(args.recall_values, recalls)])
        logger.info(recalls_str)

    # Save visualizations of predictions
    if args.num_preds_to_save != 0:
        logger.info("Saving final predictions")
        # For each query save num_preds_to_save predictions
        visualizations.save_preds(
            predictions[:, : args.num_preds_to_save], test_ds, log_dir, args.save_only_wrong_preds, args.use_labels
        )
if __name__ == "__main__":
    # Script entry point: parse the command line and run the whole evaluation.
    main(args=parser.parse_arguments())