-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathhubconf.py
75 lines (66 loc) · 2.34 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
from functools import partial
import torch
from hub.postnorm_vit import PostnormVit
from hub.prenorm_vit import PrenormVit
# Packages that torch.hub requires to be importable before any entrypoint
# from this file is loaded.
dependencies = ["torch", "kappamodules", "einops"]

# Architecture hyper-parameters per ViT size tag. These mappings are used as
# the constructor keyword arguments of the hub entrypoints in this file.
VIT_CONFIGS = {
    "debug": {"patch_size": 16, "dim": 16, "depth": 2, "num_heads": 2},
    "l16": {"patch_size": 16, "dim": 1024, "depth": 24, "num_heads": 16},
    "h14": {"patch_size": 14, "dim": 1280, "depth": 32, "num_heads": 16},
    "twob14": {"patch_size": 14, "dim": 2560, "depth": 24, "num_heads": 32},
}
# Registry of torch.hub entrypoints: entrypoint name -> arguments for load_model.
#   ctor:        model class to instantiate (PrenormVit or PostnormVit)
#   ctor_kwargs: architecture hyper-parameters taken from VIT_CONFIGS
#   url:         checkpoint to download, or None to skip loading weights
CONFIGS = {
    "debug": dict(
        ctor=PrenormVit,
        ctor_kwargs=VIT_CONFIGS["debug"],
        url=None,
    ),
    "mae_refined_l16": dict(
        ctor=PrenormVit,
        ctor_kwargs=VIT_CONFIGS["l16"],
        url="https://ml.jku.at/research/mimrefiner/download/maerefined_l16.th",
    ),
    "dbot_refined_l16": dict(
        ctor=PrenormVit,
        ctor_kwargs=VIT_CONFIGS["l16"],
        url="https://ml.jku.at/research/mimrefiner/download/dbotrefined_l16.th",
    ),
    "crossmae_refined_l16": dict(
        ctor=PrenormVit,
        ctor_kwargs=VIT_CONFIGS["l16"],
        url="https://ml.jku.at/research/mimrefiner/download/crossmaerefined_l16.th",
    ),
    "d2v2_refined_l16": dict(
        ctor=PostnormVit,
        ctor_kwargs=VIT_CONFIGS["l16"],
        url="https://ml.jku.at/research/mimrefiner/download/d2v2refined_l16.th",
    ),
    "mae_refined_h14": dict(
        ctor=PrenormVit,
        ctor_kwargs=VIT_CONFIGS["h14"],
        url="https://ml.jku.at/research/mimrefiner/download/maerefined_h14.th",
    ),
    "dbot_refined_h14": dict(
        ctor=PrenormVit,
        ctor_kwargs=VIT_CONFIGS["h14"],
        url="https://ml.jku.at/research/mimrefiner/download/dbotrefined_h14.th",
    ),
    "d2v2_refined_h14": dict(
        ctor=PostnormVit,
        ctor_kwargs=VIT_CONFIGS["h14"],
        url="https://ml.jku.at/research/mimrefiner/download/d2v2refined_h14.th",
    ),
    "mae_refined_twob14": dict(
        ctor=PrenormVit,
        ctor_kwargs=VIT_CONFIGS["twob14"],
        url="https://ml.jku.at/research/mimrefiner/download/maerefined_twob14.th",
    ),
}
# Backward-compatible alias for the original (misspelled) registry name.
CONFIS = CONFIGS
def load_model(ctor, ctor_kwargs, url, **kwargs):
    """Instantiate a model and optionally load pretrained weights into it.

    Args:
        ctor: Callable (typically a model class) used to build the model.
        ctor_kwargs: Default keyword arguments for ``ctor``. The mapping is
            not mutated.
        url: URL of a checkpoint to fetch with
            ``torch.hub.load_state_dict_from_url``, or ``None`` to return the
            freshly constructed model without loading any weights.
        **kwargs: Extra keyword arguments for ``ctor``. They take precedence
            over same-named entries in ``ctor_kwargs``.

    Returns:
        The constructed model, with the checkpoint's weights loaded when
        ``url`` is given.

    Raises:
        KeyError: If the downloaded checkpoint has no ``"state_dict"`` entry.
    """
    # Merge into one mapping instead of passing both to ctor. Passing both
    # raised "got multiple values for keyword argument" whenever a caller
    # overrode a default such as ``depth``.
    # NOTE(review): overriding architecture arguments while also loading
    # weights will presumably fail in load_state_dict. Pass url=None to build
    # an architecture variant without weights.
    model = ctor(**{**ctor_kwargs, **kwargs})
    if url is not None:
        # map_location="cpu" keeps the loaded tensors on the CPU.
        checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        # The weights live under the checkpoint's "state_dict" key.
        model.load_state_dict(checkpoint["state_dict"])
    return model
def _make_entrypoint(name, config):
    """Return a documented torch.hub entrypoint that calls load_model(**config).

    The leading underscore keeps this helper out of ``torch.hub.list()``.
    """
    entrypoint = partial(load_model, **config)
    # A bare partial reports functools.partial's generic docstring, which is
    # what torch.hub.help() would show. Give each entrypoint its own text.
    ctor_name = getattr(config["ctor"], "__name__", repr(config["ctor"]))
    if config["url"] is None:
        weights = "none (no pretrained weights are loaded)"
    else:
        weights = config["url"]
    entrypoint.__doc__ = (
        f"Build '{name}': {ctor_name}(**{config['ctor_kwargs']!r}).\n\n"
        f"Pretrained weights: {weights}\n"
        "Extra keyword arguments are forwarded to the model constructor."
    )
    return entrypoint


# Expose one module-level torch.hub entrypoint per registry entry.
for name, config in CONFIS.items():
    globals()[name] = _make_entrypoint(name, config)