eval_retrieval.py
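"""Evaluate a BLIP retrieval model on the VOC 2012 segmentation validation split.

Loads the BLIP base weights plus a retrieval head checkpoint, extracts image
and text features for each (image, caption) pair, and writes the retrieval
metrics to <output_dir>/<timestamp>/scores.json.

Example invocation (a sketch; assumes the default checkpoints under
paths.CV_PATH_CKPT and the VOC root under paths.CV_PATH_VOC exist in your
environment):

    python eval_retrieval.py --batch_size 32 --num_workers 4
"""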
import argparse
import json
import logging
import os
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

import paths
from datasets.voc12 import VOCDataset, create_image_only_transforms
from models.blip.blip_config import BlipConfig
from models.blip.blip_retrieval import BlipRetrieval
from models.preprocessing.preprocess import get_processors
from utils.gpu_profiler import GPUProfiler
from utils.logger import setup_logger
from utils.misc import get_timestamp_for_filename
from utils.retrieval_metric import evaluate_retrieval
profiler = GPUProfiler()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("-v", "--verbose", help="Increase verbosity", action="store_true")
    parser.add_argument("--cpu", help="Run on CPU instead of GPU", action="store_true")
    parser.add_argument("--batch_size", help="Batch size", type=int, default=16)
    parser.add_argument("--num_workers", help="Dataloader workers (0 for debugging, higher for "
                                              "faster dataloading)", type=int, default=0)
    parser.add_argument("--output_dir", help="Output directory",
                        type=str, default="outputs/eval_retrieval")
    parser.add_argument("--eval_ckpt", type=str, default=None,
                        help="Path to checkpoint to evaluate")
    args = parser.parse_args()
    eval_without_args(verbose=args.verbose, cpu=args.cpu, batch_size=args.batch_size,
                      num_workers=args.num_workers, output_dir=args.output_dir,
                      eval_ckpt=args.eval_ckpt)
def eval_without_args(verbose=False, cpu=False, eval_ckpt=None, output_dir="outputs/eval_retrieval",
                      batch_size=16, num_workers=0):
    setup_logger(level=logging.DEBUG if verbose else logging.INFO)
    torch.manual_seed(42)
    use_cuda = not cpu and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    logging.info(f"Running on device: {device}, cuda available: {torch.cuda.is_available()}")

    # Build the model: BLIP base weights plus the retrieval head under evaluation.
    model_cfg = BlipConfig()
    model = BlipRetrieval.from_config(model_cfg)
    logging.info(f"Created model {type(model).__name__} with {model.show_n_params()} parameters.")
    model.load_checkpoint(Path(paths.CV_PATH_CKPT) / "blip_model_base.pth")
    if eval_ckpt is None:
        eval_ckpt = Path(paths.CV_PATH_CKPT) / "blip_model_retrieval_head.pth"
    model.load_retrieval_head(eval_ckpt)
    model = model.to(device)
    model.eval()
    logging.info(f"GPU/RAM status: {profiler.profile_to_str()}")

    output_dir = Path(output_dir) / get_timestamp_for_filename()
    os.makedirs(output_dir, exist_ok=True)
    logging.info(f"Output dir: {output_dir}")

    # Set up the dataloader for the VOC 2012 segmentation validation split.
    voc_path = Path(paths.CV_PATH_VOC)
    logging.info(f"Load dataset from {voc_path}")
    vis_processor_val, text_processor_val = get_processors(model_cfg, mode="eval")
    val_dataset = VOCDataset(
        voc_path, voc_path / "ImageSets" / "Segmentation" / "val.txt", load_captions=True,
        transforms=create_image_only_transforms(vis_processor_val))
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers, drop_last=False)

    # Extract features, compute the retrieval metrics and write them to disk.
    image_feats, text_feats = validate(model, val_dataloader, device)
    val_results = evaluate_retrieval(image_feats, text_feats)
    logging.info(f"Validation results: {val_results}")
    if use_cuda:
        logging.info(f"Max GPU memory allocated: {torch.cuda.max_memory_allocated() / 1e6:.03f}M")
    with (output_dir / "scores.json").open("w", encoding="utf-8") as f:
        json.dump(val_results, f, indent=2)
def validate(model, dataloader, device):
    model.eval()
    pbar = tqdm(total=len(dataloader), desc="Generating retrieval features for eval")
    image_feats, text_feats = [], []
    for batch in dataloader:
        image = batch["image"].to(device)
        caption = batch["caption"]
        # Forward pass without gradients; move the features to CPU numpy arrays
        # so the full validation set can be accumulated without holding GPU memory.
        with torch.no_grad():
            image_feat, text_feat = model(image, caption)
        text_feats.append(text_feat.detach().cpu().numpy())
        image_feats.append(image_feat.detach().cpu().numpy())
        pbar.update(1)
    pbar.close()
    image_feats = np.concatenate(image_feats, axis=0)
    text_feats = np.concatenate(text_feats, axis=0)
    return image_feats, text_feats


if __name__ == "__main__":
    main()
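# A minimal sketch of calling the evaluation from other code rather than the CLI
# (assumes the default checkpoints under paths.CV_PATH_CKPT are present):
#
#     from eval_retrieval import eval_without_args
#     eval_without_args(batch_size=32, num_workers=4)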