Commit
Curya authored and committed on Mar 20, 2022
1 parent 463043d · commit 6396fe4
Showing 202 changed files with 18,818 additions and 0 deletions.
@@ -0,0 +1,108 @@
import os
import random
import numpy as np
import torch
import torch.utils.data as data
import lib.utils as utils
import pickle

import json
import cv2
from PIL import Image

# Image loading and preprocessing utilities
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.transforms import _pil_interp

class CocoDataset(data.Dataset):
    def __init__(
        self,
        image_ids_path,
        input_seq,
        target_seq,
        gv_feat_path,
        att_feats_folder,
        seq_per_img,
        max_feat_num
    ):
        self.max_feat_num = max_feat_num
        self.seq_per_img = seq_per_img
        # self.image_ids = utils.load_lines(image_ids_path)
        # Here image_ids_path is a JSON file holding the ids-to-path mapping dict
        with open(image_ids_path, 'r') as f:
            self.ids2path = json.load(f)                 # dict {image_id: image_path}
        self.image_ids = list(self.ids2path.keys())      # list of str

        self.att_feats_folder = att_feats_folder if len(att_feats_folder) > 0 else None
        self.gv_feat = pickle.load(open(gv_feat_path, 'rb'), encoding='bytes') if len(gv_feat_path) > 0 else None

        # Build the image preprocessing pipeline
        self.transform = transforms.Compose([
            transforms.Resize((384, 384), interpolation=_pil_interp('bicubic')),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]
        )

        if input_seq is not None and target_seq is not None:
            self.input_seq = pickle.load(open(input_seq, 'rb'), encoding='bytes')
            self.target_seq = pickle.load(open(target_seq, 'rb'), encoding='bytes')
            self.seq_len = len(self.input_seq[self.image_ids[0]][0,:])
        else:
            self.seq_len = -1
            self.input_seq = None
            self.target_seq = None

    def set_seq_per_img(self, seq_per_img):
        self.seq_per_img = seq_per_img

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        image_path = self.ids2path[image_id]
        indices = np.array([index]).astype('int')

        if self.gv_feat is not None:
            gv_feat = self.gv_feat[image_id]
            gv_feat = np.array(gv_feat).astype('float32')
        else:
            gv_feat = np.zeros((1,1))

        # Here att_feats_folder is the directory of the raw COCO images,
        # not a directory of pre-extracted features
        if self.att_feats_folder is not None:
            # att_feats = np.load(os.path.join(self.att_feats_folder, str(image_id) + '.npz'))['feat']
            # att_feats = np.array(att_feats).astype('float32')
            # Read the image and apply the preprocessing transform
            image_path = self.ids2path[image_id]
            img = cv2.imread(os.path.join(self.att_feats_folder, image_path))
            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            att_feats = self.transform(img)   # [3, 384, 384] image tensor
        else:
            # att_feats = np.zeros((1,1))
            att_feats = torch.zeros(1, 1)

        if self.max_feat_num > 0 and att_feats.shape[0] > self.max_feat_num:
            att_feats = att_feats[:self.max_feat_num, :]

        if self.seq_len < 0:
            return indices, gv_feat, att_feats

        input_seq = np.zeros((self.seq_per_img, self.seq_len), dtype='int')
        target_seq = np.zeros((self.seq_per_img, self.seq_len), dtype='int')

        n = len(self.input_seq[image_id])
        if n >= self.seq_per_img:
            sid = 0
            ixs = random.sample(range(n), self.seq_per_img)
        else:
            sid = n
            ixs = random.sample(range(n), self.seq_per_img - n)
            input_seq[0:n, :] = self.input_seq[image_id]
            target_seq[0:n, :] = self.target_seq[image_id]

        for i, ix in enumerate(ixs):
            input_seq[sid + i] = self.input_seq[image_id][ix,:]
            target_seq[sid + i] = self.target_seq[image_id][ix,:]
        return indices, input_seq, target_seq, gv_feat, att_feats
datasets/.ipynb_checkpoints/coco_dataset_e2e-checkpoint.py (107 additions, 0 deletions)
@@ -0,0 +1,107 @@
import os
import random
import numpy as np
import torch
import torch.utils.data as data
import lib.utils as utils
import pickle
import json

import cv2
from PIL import Image

# Image loading and preprocessing utilities
from torchvision import transforms
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data.transforms import _pil_interp

class CocoDataset(data.Dataset):
    def __init__(
        self,
        image_ids_path,
        input_seq,
        target_seq,
        gv_feat_path,
        att_feats_folder,
        seq_per_img,
        max_feat_num
    ):
        self.max_feat_num = max_feat_num
        self.seq_per_img = seq_per_img
        # self.image_ids = utils.load_lines(image_ids_path)
        # Here image_ids_path is a JSON file holding the ids-to-path mapping dict
        with open(image_ids_path, 'r') as f:
            self.ids2path = json.load(f)                 # dict {image_id: image_path}
        self.image_ids = list(self.ids2path.keys())      # list of str

        self.att_feats_folder = att_feats_folder if len(att_feats_folder) > 0 else None
        self.gv_feat = pickle.load(open(gv_feat_path, 'rb'), encoding='bytes') if len(gv_feat_path) > 0 else None

        # Build the image preprocessing pipeline
        self.transform = transforms.Compose([
            transforms.Resize((384, 384), interpolation=_pil_interp('bicubic')),
            transforms.ToTensor(),
            transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)]
        )

        if input_seq is not None and target_seq is not None:
            self.input_seq = pickle.load(open(input_seq, 'rb'), encoding='bytes')
            self.target_seq = pickle.load(open(target_seq, 'rb'), encoding='bytes')
            self.seq_len = len(self.input_seq[self.image_ids[0]][0,:])
        else:
            self.seq_len = -1
            self.input_seq = None
            self.target_seq = None

    def set_seq_per_img(self, seq_per_img):
        self.seq_per_img = seq_per_img

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        image_id = self.image_ids[index]
        image_path = self.ids2path[image_id]
        indices = np.array([index]).astype('int')

        if self.gv_feat is not None:
            gv_feat = self.gv_feat[image_id]
            gv_feat = np.array(gv_feat).astype('float32')
        else:
            gv_feat = np.zeros((1,1))

        # Here att_feats_folder is the directory of the raw COCO images,
        # not a directory of pre-extracted features
        if self.att_feats_folder is not None:
            # att_feats = np.load(os.path.join(self.att_feats_folder, str(image_id) + '.npz'))['feat']
            # att_feats = np.array(att_feats).astype('float32')
            # Read the image and apply the preprocessing transform
            image_path = self.ids2path[image_id]
            img = cv2.imread(os.path.join(self.att_feats_folder, image_path))
            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            att_feats = self.transform(img)   # [3, 384, 384] image tensor
        else:
            # att_feats = np.zeros((1,1))
            att_feats = torch.zeros(1, 1)

        if self.max_feat_num > 0 and att_feats.shape[0] > self.max_feat_num:
            att_feats = att_feats[:self.max_feat_num, :]

        if self.seq_len < 0:
            return indices, gv_feat, att_feats

        input_seq = np.zeros((self.seq_per_img, self.seq_len), dtype='int')
        target_seq = np.zeros((self.seq_per_img, self.seq_len), dtype='int')

        n = len(self.input_seq[image_id])
        if n >= self.seq_per_img:
            sid = 0
            ixs = random.sample(range(n), self.seq_per_img)
        else:
            sid = n
            ixs = random.sample(range(n), self.seq_per_img - n)
            input_seq[0:n, :] = self.input_seq[image_id]
            target_seq[0:n, :] = self.target_seq[image_id]

        for i, ix in enumerate(ixs):
            input_seq[sid + i] = self.input_seq[image_id][ix,:]
            target_seq[sid + i] = self.target_seq[image_id][ix,:]
        return indices, input_seq, target_seq, gv_feat, att_feats
@@ -0,0 +1,115 @@
import os
import torch
from torchvision import transforms
from lib.config import cfg
from datasets.coco_dataset import CocoDataset
import samplers.distributed
import numpy as np

def sample_collate(batch):
    indices, input_seq, target_seq, gv_feat, att_feats = zip(*batch)

    indices = np.stack(indices, axis=0).reshape(-1)
    input_seq = torch.cat([torch.from_numpy(b) for b in input_seq], 0)
    target_seq = torch.cat([torch.from_numpy(b) for b in target_seq], 0)
    gv_feat = torch.cat([torch.from_numpy(b) for b in gv_feat], 0)

    """
    # When loading pre-extracted image features of shape [L, D], the length L may differ
    # per image (e.g. region features), so the features must be padded to a common length
    # and a feature mask att_mask must be built:
    atts_num = [x.shape[0] for x in att_feats]
    max_att_num = np.max(atts_num)
    feat_arr = []
    mask_arr = []
    for i, num in enumerate(atts_num):
        tmp_feat = np.zeros((1, max_att_num, att_feats[i].shape[1]), dtype=np.float32)
        tmp_feat[:, 0:att_feats[i].shape[0], :] = att_feats[i]
        feat_arr.append(torch.from_numpy(tmp_feat))
        tmp_mask = np.zeros((1, max_att_num), dtype=np.float32)
        tmp_mask[:, 0:num] = 1
        mask_arr.append(torch.from_numpy(tmp_mask))
    att_feats = torch.cat(feat_arr, 0)
    att_mask = torch.cat(mask_arr, 0)
    """
    # Raw image tensors all have the same size, so unlike pre-extracted features they can
    # be stacked directly. att_mask covers the final grid feature size; strictly speaking,
    # grid features do not need an attention mask at all.
    att_feats = torch.stack(att_feats, 0)               # [B, 3, 384, 384]
    att_mask = torch.ones(att_feats.size()[0], 12*12)

    return indices, input_seq, target_seq, gv_feat, att_feats, att_mask

def sample_collate_val(batch):
    indices, gv_feat, att_feats = zip(*batch)

    indices = np.stack(indices, axis=0).reshape(-1)
    gv_feat = torch.cat([torch.from_numpy(b) for b in gv_feat], 0)

    """
    # Same as in sample_collate: padding and the feature mask att_mask are only needed
    # for pre-extracted features of varying length [L, D]:
    atts_num = [x.shape[0] for x in att_feats]
    max_att_num = np.max(atts_num)
    feat_arr = []
    mask_arr = []
    for i, num in enumerate(atts_num):
        tmp_feat = np.zeros((1, max_att_num, att_feats[i].shape[1]), dtype=np.float32)
        tmp_feat[:, 0:att_feats[i].shape[0], :] = att_feats[i]
        feat_arr.append(torch.from_numpy(tmp_feat))
        tmp_mask = np.zeros((1, max_att_num), dtype=np.float32)
        tmp_mask[:, 0:num] = 1
        mask_arr.append(torch.from_numpy(tmp_mask))
    att_feats = torch.cat(feat_arr, 0)
    att_mask = torch.cat(mask_arr, 0)
    """
    # Raw image tensors can be stacked directly; att_mask covers the final grid size.
    att_feats = torch.stack(att_feats, 0)               # [B, 3, 384, 384]
    att_mask = torch.ones(att_feats.size()[0], 12*12)

    return indices, gv_feat, att_feats, att_mask


def load_train(distributed, epoch, coco_set):
    sampler = samplers.distributed.DistributedSampler(coco_set, epoch=epoch) \
        if distributed else None
    shuffle = cfg.DATA_LOADER.SHUFFLE if sampler is None else False

    loader = torch.utils.data.DataLoader(
        coco_set,
        batch_size = cfg.TRAIN.BATCH_SIZE,
        shuffle = shuffle,
        num_workers = cfg.DATA_LOADER.NUM_WORKERS,
        drop_last = cfg.DATA_LOADER.DROP_LAST,
        pin_memory = cfg.DATA_LOADER.PIN_MEMORY,
        sampler = sampler,
        collate_fn = sample_collate
    )
    return loader

def load_val(image_ids_path, gv_feat_path, att_feats_folder):
    coco_set = CocoDataset(
        image_ids_path = image_ids_path,
        input_seq = None,
        target_seq = None,
        gv_feat_path = gv_feat_path,
        att_feats_folder = att_feats_folder,
        seq_per_img = 1,
        max_feat_num = cfg.DATA_LOADER.MAX_FEAT
    )

    loader = torch.utils.data.DataLoader(
        coco_set,
        batch_size = cfg.TEST.BATCH_SIZE,
        shuffle = False,
        num_workers = cfg.DATA_LOADER.NUM_WORKERS,
        drop_last = False,
        pin_memory = cfg.DATA_LOADER.PIN_MEMORY,
        collate_fn = sample_collate_val
    )
    return loader
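
A rough sketch of how the validation loader above might be driven, assuming this file lives at datasets/data_loader.py; the paths below are placeholders, and the cfg values come from lib.config, which is not part of this diff.

from datasets.data_loader import load_val    # assumed module path for the file above

val_loader = load_val(
    image_ids_path='data/val_ids2path.json',  # hypothetical JSON dict {image_id: image path}
    gv_feat_path='',                          # no global visual feature
    att_feats_folder='data/coco_images'       # root folder of the raw COCO images
)

for indices, gv_feat, att_feats, att_mask in val_loader:
    # att_feats: [B, 3, 384, 384] batch of preprocessed images
    # att_mask:  [B, 144] all-ones mask (144 = 12*12, presumably the backbone's grid at 384x384 input)
    print(att_feats.shape, att_mask.shape)
    break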