-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathpretrain_IJEPA.py
206 lines (171 loc) · 7 KB
/
pretrain_IJEPA.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import (
ModelCheckpoint,
LearningRateMonitor,
ModelSummary,
)
from pytorch_lightning.loggers import WandbLogger
from model import IJEPA_base
class IJEPADataset(Dataset):
    """Dummy dataset of identical random images, for smoke-testing the pipeline.

    A single ``(in_chans, img_size, img_size)`` image is drawn from a standard
    normal distribution and repeated ``num_samples`` times, so every index
    returns the same values. No files are read.

    Args:
        dataset_path: Kept for API compatibility with a real dataset; unused.
        stage: Split name (e.g. ``'train'`` / ``'val'``); stored but unused.
        num_samples: Number of items reported by ``__len__`` (default 100).
        img_size: Height and width of each image (default 224).
        in_chans: Number of channels of each image (default 3).
    """

    def __init__(self,
                 dataset_path,
                 stage='train',
                 num_samples=100,
                 img_size=224,
                 in_chans=3,
                 ):
        super().__init__()
        self.dataset_path = dataset_path
        self.stage = stage
        image = torch.randn(in_chans, img_size, img_size)
        # repeat() materializes num_samples independent copies of the one image,
        # giving a tensor of shape (num_samples, in_chans, img_size, img_size).
        self.data = image.repeat(num_samples, 1, 1, 1)

    def __len__(self):
        """Return the number of (identical) samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the image tensor stored at ``index``."""
        return self.data[index]
class D2VDataModule(pl.LightningDataModule):
    """Placeholder Lightning data module serving ``IJEPADataset`` train/val splits.

    Args:
        dataset_path: Forwarded to ``IJEPADataset`` for both splits.
        batch_size: Samples per batch for both loaders.
        num_workers: Number of DataLoader worker processes.
        pin_memory: Forwarded to both DataLoaders (page-locked host memory for
            faster host-to-GPU copies).
        shuffle: Whether the training loader shuffles; validation never does.
    """

    def __init__(self,
                 dataset_path,
                 batch_size=16,
                 num_workers=4,
                 pin_memory=True,
                 shuffle=True
                 ):
        super().__init__()
        self.dataset_path = dataset_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Previously accepted but silently discarded; now honoured by both loaders.
        self.pin_memory = pin_memory
        self.shuffle = shuffle

    def setup(self, stage=None):
        """Create the train and validation datasets (independent of ``stage``)."""
        self.train_dataset = IJEPADataset(dataset_path=self.dataset_path, stage='train')
        self.val_dataset = IJEPADataset(dataset_path=self.dataset_path, stage='val')

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader for ``dataset`` with this module's shared loader settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return the training loader; shuffling follows ``self.shuffle``."""
        return self._make_loader(self.train_dataset, shuffle=self.shuffle)

    def val_dataloader(self):
        """Return the validation loader, always in a fixed (unshuffled) order."""
        return self._make_loader(self.val_dataset, shuffle=False)
class IJEPA(pl.LightningModule):
    """Lightning module that pretrains ``model.IJEPA_base`` with a student/teacher setup.

    Each training/validation step samples target and context block geometry,
    runs ``IJEPA_base`` to obtain ``(y_student, y_teacher)``, and minimizes the
    MSE between them. The teacher output is detached, so the loss sends no
    gradient into the teacher. The teacher encoder's weights are instead an
    exponential moving average (EMA) of the student encoder's weights,
    refreshed once after every optimizer step.

    Args:
        img_size: Input image side length; also determines ``num_tokens``.
        patch_size: Patch side length; ``num_tokens = (img_size // patch_size) ** 2``.
        in_chans: Number of input channels.
        embed_dim: Embedding width passed to ``IJEPA_base``.
        enc_heads: Passed to ``IJEPA_base`` as ``num_heads``.
        enc_depth: Passed to ``IJEPA_base`` as ``enc_depth``.
        decoder_depth: Passed to ``IJEPA_base`` as ``pred_depth``.
        lr: Peak learning rate of the OneCycle schedule.
        weight_decay: AdamW weight decay.
        target_aspect_ratio: ``(low, high)`` range for the per-step target aspect ratio.
        target_scale: ``(low, high)`` range for the per-step target scale.
        context_aspect_ratio: Fixed context aspect ratio.
        context_scale: ``(low, high)`` range for the per-step context scale.
        M: Number of different target blocks.
        m: Initial EMA momentum.
        m_start_end: ``(start, end)`` of the momentum schedule. Momentum rises
            linearly by ``end - start`` over ``trainer.estimated_stepping_batches``
            optimizer steps and then stays there (it never overshoots).
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=64,
        enc_heads=8,
        enc_depth=8,
        decoder_depth=6,
        lr=1e-6,
        weight_decay=0.05,
        target_aspect_ratio=(0.75, 1.5),
        target_scale=(0.15, .2),
        context_aspect_ratio=1,
        context_scale=(0.85, 1.0),
        M=4,  # number of different target blocks
        m=0.996,  # initial EMA momentum
        m_start_end=(.996, 1.)
    ):
        super().__init__()
        self.save_hyperparameters()

        # define models
        self.model = IJEPA_base(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
                                enc_depth=enc_depth, num_heads=enc_heads, pred_depth=decoder_depth, M=M)

        # define hyperparameters
        self.M = M
        self.lr = lr
        self.weight_decay = weight_decay
        self.m = m
        self.target_aspect_ratio = target_aspect_ratio
        self.target_scale = target_scale
        self.context_aspect_ratio = context_aspect_ratio
        self.context_scale = context_scale
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_tokens = (img_size // patch_size) ** 2
        self.m_start_end = m_start_end
        # Anchor of the momentum schedule; self.m is recomputed from it each step
        # so the schedule is bounded and survives checkpoint resumes.
        self._m_initial = m
        # trainer.global_step at the last EMA update (skips accumulation-only batches).
        self._ema_step = 0

        # define loss
        self.criterion = nn.MSELoss()

    def forward(self, x, target_aspect_ratio, target_scale, context_aspect_ratio, context_scale):
        """Run ``IJEPA_base`` with explicit block geometry and return its output."""
        return self.model(x, target_aspect_ratio, target_scale, context_aspect_ratio, context_scale)

    def update_momentum(self, m):
        """Blend student encoder weights into the teacher: ``t <- m * t + (1 - m) * s``.

        The student encoder's train/eval mode is left untouched. The previous
        ``student_encoder.eval()`` call switched it to eval mode in place for
        the rest of the epoch. The teacher encoder is put in eval mode, as before.
        """
        student_model = self.model.student_encoder
        teacher_model = self.model.teacher_encoder.eval()
        with torch.no_grad():
            for student_param, teacher_param in zip(student_model.parameters(), teacher_model.parameters()):
                teacher_param.data.mul_(m).add_(student_param.data, alpha=1 - m)

    def _sample_block_params(self):
        """Draw this step's (target_aspect_ratio, target_scale, context_aspect_ratio, context_scale)."""
        target_aspect_ratio = np.random.uniform(self.target_aspect_ratio[0], self.target_aspect_ratio[1])
        target_scale = np.random.uniform(self.target_scale[0], self.target_scale[1])
        context_aspect_ratio = self.context_aspect_ratio
        context_scale = np.random.uniform(self.context_scale[0], self.context_scale[1])
        return target_aspect_ratio, target_scale, context_aspect_ratio, context_scale

    def _shared_step(self, batch, log_name):
        """Compute, log under ``log_name``, and return the student/teacher MSE for ``batch``."""
        x = batch
        y_student, y_teacher = self(x, *self._sample_block_params())
        # Stop-gradient on the teacher branch: the teacher is updated only via EMA.
        loss = self.criterion(y_student, y_teacher.detach())
        self.log(log_name, loss)
        return loss

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch (logged as ``train_loss``)."""
        return self._shared_step(batch, 'train_loss')

    def validation_step(self, batch, batch_idx):
        """Return the validation loss for one batch (logged as ``val_loss``)."""
        return self._shared_step(batch, 'val_loss')

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Run the model in ``"test"`` mode with the full context (``context_scale = 1``).

        Per the original author, test mode yields the teacher embedding.
        NOTE(review): this permanently sets ``self.model.mode = "test"``. Whether
        it should be restored afterwards depends on ``model.IJEPA_base``; confirm.
        """
        target_aspect_ratio = np.random.uniform(self.target_aspect_ratio[0], self.target_aspect_ratio[1])
        target_scale = np.random.uniform(self.target_scale[0], self.target_scale[1])
        context_aspect_ratio = self.context_aspect_ratio
        context_scale = 1
        self.model.mode = "test"
        return self(batch, target_aspect_ratio, target_scale, context_aspect_ratio, context_scale)

    def on_train_start(self):
        """Record the starting step count so a resumed run waits for a real optimizer step."""
        self._ema_step = self.trainer.global_step

    def on_train_batch_end(self, outputs, batch, batch_idx):
        """Update the teacher EMA once per optimizer step, then advance the momentum schedule.

        Lightning calls this hook after ``optimizer.step()``, so the teacher
        follows the freshly updated student. ``on_after_backward`` fired before
        the step and once per micro-batch. Batches that only accumulate
        gradients leave ``global_step`` unchanged and are skipped.
        """
        step = self.trainer.global_step
        if step == self._ema_step:
            return
        self._ema_step = step
        self.update_momentum(self.m)

        start, end = self.m_start_end
        total = self.trainer.estimated_stepping_batches
        # Clamp progress to [0, 1]: the old unbounded increment could push m past 1,
        # turning the EMA into an extrapolation.
        progress = min(step / total, 1.0) if total > 0 else 1.0
        self.m = self._m_initial + (end - start) * progress

    def configure_optimizers(self):
        """AdamW with a per-step OneCycle learning-rate schedule peaking at ``self.lr``.

        Teacher parameters are included in ``self.parameters()``. They receive no
        gradient from the (detached) loss, and AdamW skips parameters whose
        ``.grad`` is None.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.lr,
            total_steps=self.trainer.estimated_stepping_batches,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }
if __name__ == '__main__':
    # Smoke-test run: dummy data module plus a small I-JEPA configuration on one GPU.
    data_module = D2VDataModule(dataset_path='data')
    ijepa = IJEPA(
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=64,
        enc_heads=8,
        enc_depth=8,
        decoder_depth=6,
        lr=1e-3,
    )

    # Log the learning rate every optimizer step; summarize modules two levels deep.
    run_callbacks = [
        LearningRateMonitor(logging_interval="step"),
        ModelSummary(max_depth=2),
    ]

    trainer = pl.Trainer(
        accelerator='gpu',
        devices=1,
        precision=16,
        max_epochs=10,
        callbacks=run_callbacks,
        gradient_clip_val=.1,
    )
    trainer.fit(ijepa, data_module)