-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy path: eval_probe_timm.py
106 lines (95 loc) · 4.35 KB
/
eval_probe_timm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
from argparse import ArgumentParser
import torch
import torch.nn as nn
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models import vit_base_patch16_224, vit_large_patch16_224, vit_huge_patch14_224
from torch.utils.data import DataLoader
from torchmetrics.functional.classification import multiclass_accuracy
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, Resize, CenterCrop, InterpolationMode, Normalize, ToTensor
from tqdm import tqdm
def parse_args():
    """Parse the command line options of the evaluation script.

    Returns:
        dict: maps each option name ("root", "encoder", "head", "device",
        "precision", "pooling") to its parsed value.
    """
    cli = ArgumentParser()
    # (flag, add_argument keyword arguments) in the order they are registered
    option_specs = (
        ("--root", dict(type=str, default="/local00/bioinf/imagenet1k/val")),
        ("--encoder", dict(type=str, required=True)),
        ("--head", dict(type=str, required=True)),
        ("--device", dict(type=int, required=True)),
        ("--precision", dict(type=str, default="bfloat16", choices=["float32", "float16", "bfloat16"])),
        ("--pooling", dict(type=str, default="token", choices=["token", "avg"])),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    namespace = cli.parse_args()
    return vars(namespace)
def _load_state_dict(path):
    """Load a checkpoint onto the CPU and return its state dict.

    If the loaded object has a "state_dict" entry, that entry is returned;
    otherwise the loaded object itself is returned.
    """
    # NOTE(security): torch.load unpickles the file, only pass checkpoints from a trusted source
    checkpoint = torch.load(path, map_location=torch.device("cpu"))
    if "state_dict" in checkpoint:
        return checkpoint["state_dict"]
    return checkpoint


def _create_vit(encoder_sd, pooling):
    """Instantiate the timm ViT whose size matches the encoder state dict.

    The variant is selected from the last dimension of "pos_embed" (768, 1024 or 1280);
    for 1280 the patch size is additionally read from "patch_embed.proj.weight".

    Raises:
        NotImplementedError: if the dimension or the patch size is not supported.
    """
    dim = encoder_sd["pos_embed"].shape[2]
    vit_kwargs = dict(use_fc_norm=False, global_pool=pooling)
    if dim == 768:
        return vit_base_patch16_224(**vit_kwargs)
    if dim == 1024:
        return vit_large_patch16_224(**vit_kwargs)
    if dim == 1280:
        patch_size = encoder_sd["patch_embed.proj.weight"].shape[2]
        if patch_size == 16:
            return vit_huge_patch14_224(patch_size=16, **vit_kwargs)
        if patch_size == 14:
            return vit_huge_patch14_224(**vit_kwargs)
        raise NotImplementedError(f"unsupported patch size {patch_size} for dim {dim} (expected 14 or 16)")
    raise NotImplementedError(f"unsupported encoder dim {dim} (expected 768, 1024 or 1280)")


def main(root, encoder, head, device, precision, pooling):
    """Evaluate an encoder checkpoint plus probing head with a timm ViT and print the accuracy.

    Args:
        root: directory passed to ImageFolder; it has to contain exactly 1000 classes.
        encoder: path to the encoder checkpoint (needs a "pos_embed" entry without cls token).
        head: path to the head checkpoint ("layer.2.weight"/"layer.2.bias" and optionally
            the "layer.1.*" batchnorm statistics).
        device: GPU index that is written to CUDA_VISIBLE_DEVICES.
        precision: name of the torch dtype used for autocast ("float32" disables autocast).
        pooling: value passed as global_pool to the timm ViT constructor.

    Raises:
        ValueError: if the dataset does not have 1000 classes.
        NotImplementedError: if the encoder dimension or patch size is not supported.
    """
    print(f"initialize dataset ({root})")
    # NOTE: has to happen before the first CUDA call of the process to have an effect
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
    run_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset = ImageFolder(
        root=root,
        transform=Compose([
            Resize(size=256, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(size=224),
            ToTensor(),
            Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
        ]),
    )
    # explicit check instead of assert (asserts are stripped when running with -O)
    if len(dataset.classes) != 1000:
        raise ValueError(f"expected a dataset with 1000 classes, found {len(dataset.classes)} in {root}")

    print(f"initialize encoder ({encoder})")
    encoder_sd = _load_state_dict(encoder)
    model = _create_vit(encoder_sd, pooling)
    # timm ViT has a zero vector for the pos_embed of cls token
    encoder_sd["pos_embed"] = torch.concat([torch.zeros(1, 1, model.embed_dim), encoder_sd["pos_embed"]], dim=1)

    print(f"initialize head ({head})")
    head_sd = _load_state_dict(head)
    # patch head (probing heads use a non-affine batchnorm before the linear layer)
    if "layer.1.running_mean" in head_sd:
        model.head = nn.Sequential(
            nn.BatchNorm1d(num_features=model.embed_dim, affine=False),
            nn.Linear(model.embed_dim, model.num_classes),
        )
        encoder_sd["head.0.running_mean"] = head_sd["layer.1.running_mean"]
        encoder_sd["head.0.running_var"] = head_sd["layer.1.running_var"]
        encoder_sd["head.0.num_batches_tracked"] = head_sd["layer.1.num_batches_tracked"]
        encoder_sd["head.1.weight"] = head_sd["layer.2.weight"]
        encoder_sd["head.1.bias"] = head_sd["layer.2.bias"]
    else:
        encoder_sd["head.weight"] = head_sd["layer.2.weight"]
        encoder_sd["head.bias"] = head_sd["layer.2.bias"]
    model.load_state_dict(encoder_sd)
    model = model.to(run_device)
    model.eval()

    print(f"make predictions (precision={precision})")
    autocast_dtype = getattr(torch, precision)
    # float32 is not a reduced precision dtype -> disable autocast explicitly instead of
    # relying on autocast falling back (with a warning) for an unsupported dtype
    use_autocast = autocast_dtype != torch.float32
    preds = []
    target = []
    loader = DataLoader(dataset, batch_size=256, num_workers=10, pin_memory=True)
    with torch.no_grad():
        for x, y in tqdm(loader):
            with torch.autocast(run_device.type, dtype=autocast_dtype, enabled=use_autocast):
                preds.append(model(x.to(run_device)).cpu())
            # presumably cloned so the collected labels do not keep the loader's batch storage alive
            target.append(y.clone())
    preds = torch.concat(preds)
    target = torch.concat(target)
    acc = multiclass_accuracy(
        preds=preds.to(run_device),
        target=target.to(run_device),
        num_classes=model.num_classes,
        average="micro",
    ).item()
    print(f"accuracy: {acc:.4f}")
# script entry point: forward the parsed command line options to main as keyword arguments
if __name__ == "__main__":
    cli_kwargs = parse_args()
    main(**cli_kwargs)