From bdb3fc2a1458bf82a56a227d5dca82f9dba60226 Mon Sep 17 00:00:00 2001 From: Yunseong Lee Date: Thu, 18 Jul 2019 19:59:41 +0900 Subject: [PATCH 1/3] Port JsFusion to RnB --- config/jsfusion-whole.json | 17 ++ models/jsfusion/__init__.py | 0 models/jsfusion/attention.py | 27 ++ models/jsfusion/data_util.py | 211 +++++++++++++++ models/jsfusion/model.py | 121 +++++++++ models/jsfusion/module.py | 507 +++++++++++++++++++++++++++++++++++ models/jsfusion/sampler.py | 15 ++ 7 files changed, 898 insertions(+) create mode 100644 config/jsfusion-whole.json create mode 100644 models/jsfusion/__init__.py create mode 100644 models/jsfusion/attention.py create mode 100644 models/jsfusion/data_util.py create mode 100644 models/jsfusion/model.py create mode 100644 models/jsfusion/module.py create mode 100644 models/jsfusion/sampler.py diff --git a/config/jsfusion-whole.json b/config/jsfusion-whole.json new file mode 100644 index 0000000..6e4c5c3 --- /dev/null +++ b/config/jsfusion-whole.json @@ -0,0 +1,17 @@ +{ + "video_path_iterator": "models.jsfusion.model.JsFusionVideoPathIterator", + "pipeline": [ + { + "model": "models.jsfusion.model.JsFusionLoader", + "gpus": [0] + }, + { + "model": "models.jsfusion.model.ResNetRunner", + "gpus": [0] + }, + { + "model": "models.jsfusion.model.MCModelRunner", + "gpus": [0] + } + ] +} diff --git a/models/jsfusion/__init__.py b/models/jsfusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/models/jsfusion/attention.py b/models/jsfusion/attention.py new file mode 100644 index 0000000..405a395 --- /dev/null +++ b/models/jsfusion/attention.py @@ -0,0 +1,27 @@ +import torch +import math + +MIN_TIMESCALE=1.0 +MAX_TIMESCALE=1.0e4 + +def add_timing_signal_nd(num_frames, video_channels): + shape = [1, num_frames, video_channels] + num_dims = len(shape) - 2 + channels = shape[-1] + + position = torch.tensor(range(num_frames), dtype=torch.float32) + position = torch.unsqueeze(position, dim=1) + + num_timescales = channels // (num_dims * 2) + log_timescale_increment = math.log(MAX_TIMESCALE / MIN_TIMESCALE) / (num_timescales - 1) + inv_timescales = [] + for i in range(num_timescales): + inv_timescales.append(1.0 * math.exp(-float(i) * log_timescale_increment)) + inv_timescales = torch.tensor(inv_timescales, dtype=torch.float32) + inv_timescales = torch.unsqueeze(inv_timescales, dim=0) + + scaled_time = position.matmul(inv_timescales) + signal = torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1) + signal = torch.unsqueeze(signal, 0) + + return signal diff --git a/models/jsfusion/data_util.py b/models/jsfusion/data_util.py new file mode 100644 index 0000000..323c0ed --- /dev/null +++ b/models/jsfusion/data_util.py @@ -0,0 +1,211 @@ +"""Utility class used in JSFusion model, copied from the original author's code +https://github.com/yj-yu/lsmdc/blob/master/videocap/datasets/data_util.py +""" +import time +import numpy as np +import re + + +def clean_str(string, downcase=True): + """Tokenization/string cleaning for strings. + + Taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py + """ + string = re.sub(r"[^A-Za-z0-9(),!?\'\`(_____)]", " ", string) + string = re.sub(r"\'s", " \'s", string) + string = re.sub(r"\'ve", " \'ve", string) + string = re.sub(r"n\'t", " n\'t", string) + string = re.sub(r"\'re", " \'re", string) + string = re.sub(r"\'d", " \'d", string) + string = re.sub(r"\'ll", " \'ll", string) + string = re.sub(r",", " , ", string) + string = re.sub(r"!", " ! 
", string) + string = re.sub(r"\(", " \( ", string) + string = re.sub(r"\)", " \) ", string) + string = re.sub(r"\?", " \? ", string) + string = re.sub(r"\s{2,}", " ", string) + return string.strip().lower() if downcase else string.strip() + +def recover_word(string): + string = re.sub(r" \'s", "\'s", string) + string = re.sub(r" ,", ",", string) + return string + +def clean_blank(blank_sent): + """Tokenizes and changes _____ to + would be Answer position in FIB work. + """ + clean_sent = clean_str(blank_sent).split() + return ['' if x == '_____' else x for x in clean_sent] + + +def clean_root(string): + """Removes unexpected character in root. + """ + return string + + +def pad_sequences(sequences, pad_token="[PAD]", pad_location="LEFT", max_length=None): + """Pads all sequences to the same length. The length is defined by the longest sequence. + Returns padded sequences. + """ + if not max_length: + max_length = max(len(x) for x in sequences) + + result = [] + for i in range(len(sequences)): + sentence = sequences[i] + num_padding = max_length - len(sentence) + if num_padding == 0: + new_sentence = sentence + elif num_padding < 0: + new_sentence = sentence[:num_padding] + elif pad_location == "RIGHT": + new_sentence = sentence + [pad_token] * num_padding + elif pad_location == "LEFT": + new_sentence = [pad_token] * num_padding + sentence + else: + print("Invalid pad_location. Specify LEFT or RIGHT.") + result.append(new_sentence) + return result + + +def convert_sent_to_index(sentence, word_to_index): + """Converts sentence consisting of string to indexed sentence. + """ + return [word_to_index[word] if word in word_to_index.keys() else 0 for word in sentence] + + +def batch_iter(data, batch_size, seed=None, fill=True): + """Generates a batch iterator for a dataset. + """ + random = np.random.RandomState(seed) + data_length = len(data) + num_batches = int(data_length / batch_size) + if data_length % batch_size != 0: + num_batches += 1 + + # Shuffle the data at each epoch + shuffle_indices = random.permutation(np.arange(data_length)) + for batch_num in range(num_batches): + start_index = batch_num * batch_size + end_index = min((batch_num + 1) * batch_size, data_length) + selected_indices = shuffle_indices[start_index:end_index] + # If we don't have enough data left for a whole batch, fill it randomly + if fill and end_index >= data_length: + num_missing = batch_size - len(selected_indices) + selected_indices = np.concatenate([selected_indices, random.randint(0, data_length, num_missing)]) + yield [data[i] for i in selected_indices] + + +def fsr_iter(fsr_data, batch_size, random_seed=42, fill=True): + """fsr_data: one of LSMDCData.build_data(), [[video_features], [sentences], [roots]] + return per iter: [[feature]*batch_size, [sentences]*batch_size, [roots]*batch] + + Usage: + train_data, val_data, test_data = LSMDCData.build_data() + for features, sentences, roots in fsr_iter(train_data, 20, 10): + feed_dict = {model.video_feature : features, + model.sentences : sentences, + model.roots : roots} + """ + + train_iter = batch_iter(list(zip(*fsr_data)), batch_size, fill=fill, seed=random_seed) + return map(lambda batch: zip(*batch), train_iter) + + +def preprocess_sents(descriptions, word_to_index, max_length): + descriptions = [clean_str(sent).split() for sent in descriptions] + # Add padding on the right to each sentence in order to keep the same lengths. 
+ descriptions = pad_sequences(descriptions, max_length=max_length) + # Convert sentences from a list of string to the list of indices (int) + descriptions = [convert_sent_to_index(sent, word_to_index) for sent in descriptions] + + return descriptions + # remove punctuation mark and special chars from root. + + +def preprocess_roots(roots, word_to_index): + roots = [clean_root(root) for root in roots] + # convert string to int index. + roots = [word_to_index[root] if root in word_to_index.keys() else 0 for root in roots] + + return roots + + +def pad_video(video_feature, dimension, padded_feature=None): + """Fills pad to video to have same length. + Pad in Left. + video = [pad,..., pad, frm1, frm2, ..., frmN] + """ + if padded_feature is None: + padded_feature = np.zeros(dimension, dtype=np.float32) + max_length = dimension[0] + current_length = video_feature.shape[0] + num_padding = max_length - current_length + if num_padding == 0: + padded_feature[:] = video_feature + elif num_padding < 0: + steps = np.linspace(0, current_length, num=max_length, endpoint=False, dtype=np.int32) + padded_feature[:] = video_feature[steps] + else: + # about 0.7 sec + padded_feature[num_padding:] = video_feature + + return padded_feature + +def repeat_pad_video(video_feature, dimension): + padded_feature = np.zeros(dimension, dtype= np.float) + max_length = dimension[0] + current_length = video_feature.shape[0] + + if current_length == max_length: + padded_feature[:] = video_feature + + elif current_length < max_length: + tile_num = int(max_length / current_length) + to_tile = np.ones(len(dimension), dtype=np.int32) + to_tile[0] = tile_num + remainder = max_length % current_length + tiled_vid = np.tile(video_feature, to_tile) + if remainder > 0: + padded_feature[0:remainder] = video_feature[-remainder:] + padded_feature[remainder:] = tiled_vid + + else: + steps = np.linspace(0, current_length, num=max_length, endpoint=False, dtype=np.int32) + padded_feature[:] = video_feature[steps] + return padded_feature + +def stretch_pad_video(video_feature, dimension): + padded_feature = np.zeros(dimension, dtype= np.float) + max_length = dimension[0] + current_length = video_feature.shape[0] + + if current_length == max_length: + padded_feature[:] = video_feature + elif current_length < max_length: + repeat_num = int((max_length-1) / current_length)+1 + tiled_vid = np.repeat(video_feature, repeat_num,0) + steps = np.linspace(0, repeat_num*current_length, num=max_length, endpoint=False, dtype=np.int32) + padded_feature[:] = tiled_vid[steps] + else: + steps = np.linspace(0, current_length, num=max_length, endpoint=False, dtype=np.int32) + padded_feature[:] = video_feature[steps] + return padded_feature + + +def fill_mask(max_length, current_length, zero_location='LEFT'): + num_padding = max_length - current_length + if num_padding <= 0: + mask = np.ones(max_length) + elif zero_location == 'LEFT': + mask = np.ones(max_length) + for i in range(num_padding): + mask[i] = 0 + elif zero_location == 'RIGHT': + mask = np.zeros(max_length) + for i in range(current_length): + mask[i] = 1 + + return mask diff --git a/models/jsfusion/model.py b/models/jsfusion/model.py new file mode 100644 index 0000000..50eee1e --- /dev/null +++ b/models/jsfusion/model.py @@ -0,0 +1,121 @@ +from models.jsfusion.module import ResNetFeatureExtractor +from models.jsfusion.module import MCModel +from models.jsfusion.sampler import FixedSampler + +from runner_model import RunnerModel +from video_path_provider import VideoPathIterator +from itertools 
import cycle +from torchvision import transforms +import torch +import nvvl +import os + +class JsFusionVideoPathIterator(VideoPathIterator): + def __init__(self): + super(JsFusionVideoPathIterator, self).__init__() + + videos = [] + video_dir = os.path.join(os.environ['LSMDC_PATH'], 'mp4s') + for video in os.listdir(video_dir): + videos.append(os.path.join(video_dir, video)) + + if len(videos) <= 0: + raise Exception('No video available.') + + self.videos_iter = cycle(videos) + + def __iter__(self): + return self.videos_iter + +class JsFusionLoader(RunnerModel): + """Impl of loading video frames using NVVL, for the R(2+1)D model.""" + def __init__(self, device): + self.loader = nvvl.RnBLoader(width=224, height=224, + consecutive_frames=1, device_id=device.index, + sampler=FixedSampler(num_frames=40)) + + samples = [ + os.path.join(os.environ['LSMDC_PATH'], 'mp4s/1004_Juno_00.00.32.849-00.00.35.458.mp4'), + os.path.join(os.environ['LSMDC_PATH'], 'mp4s/1004_Juno_00.00.35.642-00.00.45.231.mp4'), + os.path.join(os.environ['LSMDC_PATH'], 'mp4s/1004_Juno_00.00.49.801-00.00.59.450.mp4')] + + # warm up GPU with a few inferences + for sample in samples: + self.loader.loadfile(sample) + for frames in self.loader: + pass + self.loader.flush() + + def __call__(self, input): + _, file_path = input + self.loader.loadfile(file_path) + for frames in self.loader: + pass + self.loader.flush() + + + # frames: (40, 3, 1, 224, 224) + frames = frames.float() + frames = frames.permute(0, 2, 1, 3, 4) + + transform = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + frames_tmp = [] + for frame in frames: + frame = torch.squeeze(frame) + frame /= 255 + frame = transform(frame) + frames_tmp.append(frame) + frames = torch.stack(frames_tmp) + # frames: (40, 3, 224, 224) + + filename = os.path.basename(file_path) + out = (frames, filename) + return out + + def __del__(self): + self.loader.close() + + def input_shape(self): + return None + + @staticmethod + def output_shape(): + return ((40, 3, 224, 224),) + + +class ResNetRunner(RunnerModel): + def __init__(self, device, num_frames = 40): + super(ResNetRunner, self).__init__(device) + self.model = ResNetFeatureExtractor(num_frames).to(device) + self.model.float() + self.model.eval() + + def input_shape(self): + return ((40, 3, 224, 224),) + + @staticmethod + def output_shape(): + return ((1, 40, 2048),) + + def __call__(self, input): + return self.model(input) + + +class MCModelRunner(RunnerModel): + def __init__(self, device, num_frames = 40): + super(MCModelRunner, self).__init__(device) + self.model = MCModel(device).to(device) + self.model.float() + self.model.eval() + + def input_shape(self): + return ((1, 40, 2048),) + + def __call__(self, input): + return self.model(input) + + @staticmethod + def output_shape(): + return ((1,),) + diff --git a/models/jsfusion/module.py b/models/jsfusion/module.py new file mode 100644 index 0000000..d6cb418 --- /dev/null +++ b/models/jsfusion/module.py @@ -0,0 +1,507 @@ +"""A PyTorch implementation of the JSFusion model. +See https://github.com/yj-yu/lsmdc for the original +TensorFlow implementation from the authors. +A Joint Sequence Fusion Model for Video Question Answering and Retrieval, +Yu et al., ECCV 2018. 
+""" +import torch +import torch.nn.functional as F +import numpy as np + +import time +import os + +from models.jsfusion.attention import add_timing_signal_nd +import hickle as hkl +from torchvision import models +import math + +class ResNetFeatureExtractor(torch.nn.Module): + def __init__(self, num_frames = 40): + super(ResNetFeatureExtractor, self).__init__() + self.resnet = models.resnet152(pretrained=True) + self.num_frames = num_frames + + module_list = list(self.resnet.children()) + self.pool = torch.nn.Sequential(*module_list[:-1]) + self.in_features = module_list[-1].in_features # 2048 for ResNet152 + + def forward(self, tensors): + (frames,), filename = tensors + resnet_output = self.pool(frames) + resnet_output = resnet_output.view(resnet_output.shape[0], resnet_output.shape[1]) + # TODO handle the case when resnet_output.shape[0] < num_frames (fill zeros) + resnet_output = resnet_output[:self.num_frames, :] + + return ((resnet_output,), filename) + + +class MCModel(torch.nn.Module): + + def __init__(self, device, dropout_prob = 0.5, video_channels = 2048, num_frames = 40): + super(MCModel, self).__init__() + + self.device = device + + self.num_frames = num_frames + self.register_buffer('mask', torch.ones((self.num_frames), dtype=torch.float32)) + self.register_buffer('one', torch.tensor(1, dtype=torch.int32)) + self.register_buffer('signal', add_timing_signal_nd(self.num_frames, video_channels)) + + self.dropout = torch.nn.Dropout(p=dropout_prob) + self.conv1 = torch.nn.Conv2d(2048, 2048, [3, 1], padding=(1, 0)) + self.relu1 = torch.nn.ReLU() + self.bn1 = torch.nn.BatchNorm2d(2048, eps=0.001, momentum=0.001) + self.conv2 = torch.nn.Conv2d(2048, 2048, [3, 1], padding=(1, 0)) + self.relu2 = torch.nn.ReLU() + self.bn2 = torch.nn.BatchNorm2d(2048, eps=0.001, momentum=0.001) + self.conv3 = torch.nn.Conv2d(2048, 2048, [3, 1], padding=(1, 0)) + self.relu3 = torch.nn.ReLU() + self.bn3 = torch.nn.BatchNorm2d(2048, eps=0.001, momentum=0.001) + + self.sigmoid = torch.nn.Sigmoid() + + self.fc4 = torch.nn.Linear(1024+video_channels, 512) + self.tanh4 = torch.nn.Tanh() + self.bn4 = torch.nn.BatchNorm1d(512, eps=0.001, momentum=0.001) + + embedding_matrix = hkl.load(os.path.join(os.environ['LSMDC_PATH'], 'hkls/common_word_matrix_py3.hkl')) + embedding_matrix = torch.tensor(embedding_matrix, dtype=torch.float32) + self.register_buffer('embedding_matrix', embedding_matrix) + + + self.lstm = torch.nn.LSTM(300, 512, 1, batch_first=True, dropout=dropout_prob, bidirectional=True) + self.fc5 = torch.nn.Linear(512*2, 512) + self.tanh5 = torch.nn.Tanh() + self.bn5 = torch.nn.BatchNorm1d(512, eps=0.001, momentum=0.001) + + + self.fusion_fc1 = torch.nn.Linear(512, 512) + self.fusion_tanh1 = torch.nn.Tanh() + self.fusion_bn1 = torch.nn.BatchNorm2d(512, eps=0.001, momentum=0.001) + self.fusion_gate2 = torch.nn.Linear(512, 1) + self.fusion_sigmoid2 = torch.nn.Sigmoid() + self.fusion_bn2 = torch.nn.BatchNorm2d(1, eps=0.001, momentum=0.001) + + self.fusion_fc3 = torch.nn.Linear(512, 512) + self.fusion_tanh3 = torch.nn.Tanh() + self.fusion_bn3 = torch.nn.BatchNorm2d(512, eps=0.001, momentum=0.001) + self.fusion_fc4 = torch.nn.Linear(512, 512) + self.fusion_tanh4 = torch.nn.Tanh() + self.fusion_bn4 = torch.nn.BatchNorm2d(512, eps=0.001, momentum=0.001) + + + self.fusion_next_conv1 = torch.nn.Conv2d(512, 256, [3, 3]) + self.fusion_next_tanh1 = torch.nn.Tanh() + self.fusion_next_convalp1 = torch.nn.Conv2d(512, 1, [3, 3]) + self.fusion_next_tanhalp1 = torch.nn.Tanh() + self.fusion_next_gate2 = torch.nn.Sigmoid() 
+ + self.fusion_next_conv2 = torch.nn.Conv2d(256, 256, [3, 3]) + self.fusion_next_tanh2 = torch.nn.Tanh() + self.fusion_next_convalp2 = torch.nn.Conv2d(256, 1, [3, 3]) + self.fusion_next_tanhalp2 = torch.nn.Tanh() + self.fusion_next_gate3 = torch.nn.Sigmoid() + + self.fusion_next_conv3 = torch.nn.Conv2d(256, 256, [3, 3], stride=(2, 2)) + self.fusion_next_tanh3 = torch.nn.Tanh() + self.fusion_next_convalp3 = torch.nn.Conv2d(256, 1, [3, 3], stride=(2, 2)) + self.fusion_next_tanhalp3 = torch.nn.Tanh() + self.fusion_next_gate4 = torch.nn.Sigmoid() + + + self.final_fc1 = torch.nn.Linear(256, 256) + self.final_tanh1 = torch.nn.Tanh() + self.final_bn1 = torch.nn.BatchNorm1d(256, eps=0.001, momentum=0.001) + self.final_fc2 = torch.nn.Linear(256, 256) + self.final_tanh2 = torch.nn.Tanh() + self.final_bn2 = torch.nn.BatchNorm1d(256, eps=0.001, momentum=0.001) + self.final_fc3 = torch.nn.Linear(256, 128) + self.final_tanh3 = torch.nn.Tanh() + self.final_bn3 = torch.nn.BatchNorm1d(128, eps=0.001, momentum=0.001) + self.final_fc4 = torch.nn.Linear(128, 1) + self.final_bn4 = torch.nn.BatchNorm1d(1, eps=0.001, momentum=0.001) + self.word2idx = hkl.load(os.path.join(os.environ['LSMDC_PATH'], 'hkls/common_word_to_index_py3.hkl')) + + + def video_embeddings(self, video, mask): + # BxLxC + embedded_feat_tmp = video + self.signal + embedded_feat = embedded_feat_tmp * torch.unsqueeze(mask, 2) + embedded_feat_drop = self.dropout(embedded_feat) + + # BxCxL + video_emb = embedded_feat_drop.permute(0, 2, 1) + + # BxCxLx1 + video_emb_unsqueezed = torch.unsqueeze(video_emb, 3) + + conv1 = self.conv1(video_emb_unsqueezed) + relu1 = self.relu1(conv1) + bn1 = self.bn1(relu1) + + conv2 = self.conv2(bn1) + relu2 = self.relu2(conv2) + bn2 = self.bn2(relu2) + + conv3 = self.conv3(bn2) + relu3 = self.relu3(conv3) + bn3 = self.bn3(relu3) + + # Bx2048xL + outputs = torch.squeeze(bn3, 3) + input_pass = outputs[:, 0:1024, :] + input_gate = outputs[:, 1024:, :] + input_gate = self.sigmoid(input_gate) + outputs = input_pass * input_gate + + # Bx(C+2048)xL + outputs = torch.cat([outputs, video_emb], dim=1) + + # BxLx(C+2048) + outputs = outputs.permute(0, 2, 1) + + # BxLx512 + fc4 = self.fc4(outputs) + tanh4 = self.tanh4(fc4) + + # Bx512xL + tanh4 = tanh4.permute(0, 2, 1) + bn4 = self.bn4(tanh4) + + masked_outputs = bn4 * torch.unsqueeze(mask, 1) + return masked_outputs + + + def word_embeddings(self, captions, caption_masks): + # 5BxL + captions = captions.view(-1, captions.shape[-1]) + # 5BxLxH + seq_embeddings = self.embeddings(captions) + + # 5BxLx1 + caption_masks = caption_masks.view(-1, caption_masks.shape[-1], 1) + + # 5BxLxH + embedded_sentence = seq_embeddings * caption_masks + + # 5BxLx1024 + outputs, _ = self.lstm(embedded_sentence) + + # 5BxLx512 + fc5 = self.fc5(outputs) + tanh5 = self.tanh5(fc5) + + # 5Bx512xL + tanh5 = tanh5.permute(0, 2, 1) + bn5 = self.bn5(tanh5) + + rnn_output = bn5 * caption_masks.view(caption_masks.shape[0], 1, caption_masks.shape[1]) + return rnn_output + + + def fusion(self, v, w, mask, caption_masks): + # 5Bx512xL + v = v.repeat(5, 1, 1) + + # 5Bx512xLx1 + vv = torch.unsqueeze(v, 3) + + # 5Bx512x1xL + ww = torch.unsqueeze(w, 2) + + # 5Bx512xLxL + cnn_repr = vv * ww + + # 5BxLxLx512 + cnn_repr = cnn_repr.permute(0, 2, 3, 1) + + # 5BxLxLx512 + fc1 = self.fusion_fc1(cnn_repr) + tanh1 = self.fusion_tanh1(fc1) + tanh1 = tanh1.permute(0, 3, 1, 2) + bn1 = self.fusion_bn1(tanh1) + bn1 = bn1.permute(0, 2, 3, 1) + + # 5BxLxLx1 + gate2 = self.fusion_gate2(bn1) + sigmoid2 = self.fusion_sigmoid2(gate2) 
+ sigmoid2 = sigmoid2.permute(0, 3, 1, 2) + bn2 = self.fusion_bn2(sigmoid2) + # 5Bx1xLxL + + # 5BxLxLx512 + fc3 = self.fusion_fc3(cnn_repr) + tanh3 = self.fusion_tanh3(fc3) + tanh3 = tanh3.permute(0, 3, 1, 2) + bn3 = self.fusion_bn3(tanh3) + bn3 = bn3.permute(0, 2, 3, 1) + + # 5BxLxLx512 + fc4 = self.fusion_fc4(bn3) + tanh4 = self.fusion_tanh4(fc4) + tanh4 = tanh4.permute(0, 3, 1, 2) + bn4 = self.fusion_bn4(tanh4) + # 5Bx512xLxL + + # 5Bx512xLxL + output1 = bn4 * bn2 + + # Bx1xLx1 + shape = mask.shape + mask = torch.reshape(mask, (shape[0], 1, shape[1], 1)) + + # 5Bx1xLx1 + mask = mask.repeat(5, 1, 1, 1) + + # 5Bx1x1xL + shape = caption_masks.shape + caption_masks = torch.reshape(caption_masks, (shape[0] * shape[1], 1, 1, shape[2])) + + # 5Bx512xLxL + output1 = output1 * mask * caption_masks + + return output1 + + + def fusion_next(self, output1, mask, caption_masks): + # 5BxL + caption_masks = torch.reshape(caption_masks, (-1, caption_masks.shape[-1])) + + cut_mask_list = [] + cut_caption_masks_list = [] + cut_mask = mask[:, :-2] + cut_mask[:, -1] = 1. + + # Bx(L-2) + cut_mask_list.append(cut_mask.repeat(5, 1)) + cut_caption_masks = caption_masks[:, 2:] + cut_caption_masks[:, 0] = 1. + # 5Bx(L-2) + cut_caption_masks_list.append(cut_caption_masks) + + cut_mask = cut_mask[:, :-2] + cut_mask[:, -1] = 1. + # Bx(L-4) + cut_mask_list.append(cut_mask.repeat(5, 1)) + cut_caption_masks = cut_caption_masks[:, 2:] + cut_caption_masks[:, 0] = 1. + # 5Bx(L-4) + cut_caption_masks_list.append(cut_caption_masks) + + + max_len = (mask.shape[1] - 5) // 2 + cut_mask_len = (torch.sum(cut_mask, 1, dtype=torch.int32) - 1) / 2 + cut_mask_len = torch.max(cut_mask_len, self.one) + cut_caption_masks_len = (torch.sum(cut_caption_masks, 1, dtype=torch.int32) - 1) / 2 + cut_caption_masks_len = torch.max(cut_caption_masks_len, self.one) + + + cut_mask_indices = [i for i in range(cut_mask.shape[1]) if i % 2 == 1 and i < cut_mask.shape[1] - 1] + cut_mask_indices = torch.tensor(cut_mask_indices) + cut_mask_indices = cut_mask_indices.to(device=self.device, non_blocking=True) + + # cut_mask = torch.tensor([([0]*(max_len - l) + [1]*l) for l in cut_mask_len.cpu().numpy()], dtype=torch.float32) + cut_mask = torch.index_select(cut_mask, 1, cut_mask_indices) + + cut_caption_masks_indices = [i for i in range(cut_caption_masks.shape[1]) if i % 2 == 1 and i > 1] + cut_caption_masks_indices = torch.tensor(cut_caption_masks_indices) + cut_caption_masks_indices = cut_caption_masks_indices.to(device=self.device, non_blocking=True) + + + # cut_caption_masks = torch.tensor([([1]*l + [0]*(max_len - l)) for l in cut_caption_masks_len.cpu().numpy()], dtype=torch.float32) + cut_caption_masks = torch.index_select(cut_caption_masks, 1, cut_caption_masks_indices) + + cut_mask_list.append(cut_mask.repeat(5, 1)) + cut_caption_masks_list.append(cut_caption_masks) + + + # 5Bx256x(L-2)x(L-2) + conv1 = self.fusion_next_conv1(output1) + tanh1 = self.fusion_next_tanh1(conv1) + + # 5Bx1x(L-2)x(L-2) + convalp1 = self.fusion_next_convalp1(output1) + tanhalp1 = self.fusion_next_tanhalp1(convalp1) + gate2 = self.fusion_next_gate2(tanhalp1) + + # 5Bx256x(L-2)x(L-2) + output2 = tanh1 * gate2 + + # (5B, L-2) + shape = cut_mask_list[0].shape + mask = torch.reshape(cut_mask_list[0], (shape[0], 1, shape[1], 1)) + # (5B, L-2) + shape = cut_caption_masks_list[0].shape + caption_masks = torch.reshape(cut_caption_masks_list[0], (shape[0], 1, 1, shape[1])) + + # 5Bx256x(L-2)x(L-2) + output2 = output2 * mask * caption_masks + + + # 5Bx256x(L-4)x(L-4) + conv2 = 
self.fusion_next_conv2(output2) + tanh2 = self.fusion_next_tanh2(conv2) + + # 5Bx1x(L-4)x(L-4) + convalp2 = self.fusion_next_convalp2(output2) + tanhalp2 = self.fusion_next_tanhalp2(convalp2) + gate3 = self.fusion_next_gate3(tanhalp2) + + # 5Bx256x(L-4)x(L-4) + output3 = tanh2 * gate3 + + # (5B, L-4) + shape = cut_mask_list[1].shape + mask = torch.reshape(cut_mask_list[1], (shape[0], 1, shape[1], 1)) + # (5B, L-4) + shape = cut_caption_masks_list[1].shape + caption_masks = torch.reshape(cut_caption_masks_list[1], (shape[0], 1, 1, shape[1])) + + # 5Bx256x(L-4)x(L-4) + output3 = output3 * mask * caption_masks + + + # 5Bx256xhalfxhalf + conv3 = self.fusion_next_conv3(output3) + tanh3 = self.fusion_next_tanh3(conv3) + + # 5Bx1xhalfxhalf + convalp3 = self.fusion_next_convalp3(output3) + tanhalp3 = self.fusion_next_tanhalp3(convalp3) + gate4 = self.fusion_next_gate4(tanhalp3) + + # 5Bx256xhalfxhalf + output4 = tanh3 * gate4 + + # (5B, half) + shape = cut_mask_list[2].shape + mask = torch.reshape(cut_mask_list[2], (shape[0], 1, shape[1], 1)) + # (5B, half) + shape = cut_caption_masks_list[2].shape + caption_masks = torch.reshape(cut_caption_masks_list[2], (shape[0], 1, 1, shape[1])) + + # 5Bx256xhalfxhalf + output4 = output4 * mask * caption_masks + + + # 5B + valid = torch.sum(cut_mask_list[2], 1) * torch.sum(cut_caption_masks_list[2], 1) + sum_state = torch.sum(output4, (2, 3)) / torch.unsqueeze(valid, 1) + + return sum_state + + + def final(self, fusion_next): + # 5Bx256 + a = self.final_fc1(fusion_next) + a = self.final_tanh1(a) + a = self.final_bn1(a) + + # 5Bx256 + a = self.final_fc2(a) + a = self.final_tanh2(a) + a = self.final_bn2(a) + + # 5Bx128 + a = self.final_fc3(a) + a = self.final_tanh3(a) + a = self.final_bn3(a) + + # 5Bx1 + a = self.final_fc4(a) + a = self.final_bn4(a) + + return torch.reshape(-a, (-1, 5)) + + + def parse_sentences(self, word2idx, mc, max_length): + import numpy as np + def sentence_to_words(sentence): + from models.jsfusion.data_util import clean_str + try: + words = clean_str(sentence).split() + except: + print('[ERROR] sentence is broken: ' + sentence) + sys.exit(1) + + for w in words: + if not w: + continue + yield w + + def sentence_to_matrix(word2idx, sentence, max_length): + indices = [word2idx[w] for w in + sentence_to_words(sentence) + if w in word2idx] + length = min(len(indices), max_length) + return indices[:length] + + with open(mc, 'r') as f: + sentences = [sentence_to_matrix(word2idx, f.readline().strip(), max_length) + for _ in range(5)] + + sentences_tmps = [] + for sent in sentences: + sentences_tmps.append(sent + [0] * (max_length - len(sent))) + sentences = sentences_tmps + + sentence_masks = [] + for sent in sentences: + sentence_masks.append([(1 if s != 0 else 0) for s in sent]) + + sentences = np.asarray(sentences, dtype=np.int32) + sentences = np.reshape(sentences, (1, 5, -1)) + + sentence_masks = np.asarray(sentence_masks, dtype=np.float32) + sentence_masks = np.reshape(sentence_masks, (1, 5, -1)) + + return sentences, sentence_masks + + + def forward(self, tensors): + """Main inference function for the model. + resnet_output: torch.Tensor device='cuda' shape=(BxLx2048) dtype=float32 + """ + (resnet_output,), filename = tensors + + self.mask = self.mask.squeeze(0) + if resnet_output.shape[0] < self.num_frames: + more_zeros = self.num_frames - resnet_output.shape[0] + self.mask[:more_zeros] = 0. 
+ + elif resnet_output.shape[0] > self.num_frames: + print('Movie %s is over %d frames' % (video, self.num_frames)) + self.mask = self.mask[:self.num_frames] + + # mask: torch.Tensor shape=(BxL) dtype=float32 + self.mask = torch.unsqueeze(self.mask, 0) + + mc_path = os.path.join(os.environ['LSMDC_PATH'], 'texts', os.path.splitext(filename)[0] + '.txt') + + # sentences: np.ndarray shape=(Bx5xL) dtype=int32 + # sentence_masks: np.ndarray shape=(Bx5xL) dtype=float32 + sentences, sentence_masks = self.parse_sentences(self.word2idx, mc_path, self.num_frames) + sentences = torch.tensor(sentences, dtype=torch.long).to(self.device) + sentence_masks = torch.tensor(sentence_masks, dtype=torch.float32).to(self.device) + + # Bx512xL + d1v = self.video_embeddings(resnet_output, self.mask) + + # 5Bx512xL + self.embeddings = torch.nn.Embedding.from_pretrained(self.embedding_matrix, freeze=False) + d1w = self.word_embeddings(sentences, sentence_masks) + + # 5Bx512xLxL + fusion = self.fusion(d1v, d1w, self.mask, sentence_masks) + + # 5Bx256 + fusion_next = self.fusion_next(fusion, self.mask, sentence_masks) + + # Bx5 + logits = self.final(fusion_next) + + # B + winners = torch.argmax(logits, dim=1) + + return ((winners,), None) diff --git a/models/jsfusion/sampler.py b/models/jsfusion/sampler.py new file mode 100644 index 0000000..02a86a8 --- /dev/null +++ b/models/jsfusion/sampler.py @@ -0,0 +1,15 @@ +from nvvl import Sampler +import numpy as np + +class FixedSampler(Sampler): + def __init__(self, num_frames): + self.num_frames = num_frames + + def _sample(self, length, num_frames): + if length <= self.num_frames: + return range(length) + else: + return np.linspace(0, length, self.num_frames, endpoint=False, dtype=np.int32) + + def sample(self, length): + return self._sample(length, self.num_frames) From b83d92ee5ff8818f3b05b77a3e493c9d7c5c76cf Mon Sep 17 00:00:00 2001 From: Yunseong Lee Date: Thu, 18 Jul 2019 14:36:28 +0000 Subject: [PATCH 2/3] Fix type mismatch in JSFusionLoader --- models/jsfusion/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/jsfusion/model.py b/models/jsfusion/model.py index 50eee1e..4c1dcdd 100644 --- a/models/jsfusion/model.py +++ b/models/jsfusion/model.py @@ -70,7 +70,7 @@ def __call__(self, input): # frames: (40, 3, 224, 224) filename = os.path.basename(file_path) - out = (frames, filename) + out = ((frames,), filename) return out def __del__(self): From fa3818e2634729b64e860620b92b2b1f49bf3505 Mon Sep 17 00:00:00 2001 From: Yunseong Lee Date: Thu, 18 Jul 2019 14:37:20 +0000 Subject: [PATCH 3/3] Make sentence parsing efficient --- models/jsfusion/module.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/models/jsfusion/module.py b/models/jsfusion/module.py index d6cb418..e26e4d8 100644 --- a/models/jsfusion/module.py +++ b/models/jsfusion/module.py @@ -434,27 +434,18 @@ def sentence_to_matrix(word2idx, sentence, max_length): indices = [word2idx[w] for w in sentence_to_words(sentence) if w in word2idx] - length = min(len(indices), max_length) - return indices[:length] + if len(indices) >= max_length: + return indices[:len(indices)] + else: + return indices + [0] * (max_length - len(indices)) with open(mc, 'r') as f: sentences = [sentence_to_matrix(word2idx, f.readline().strip(), max_length) for _ in range(5)] - sentences_tmps = [] - for sent in sentences: - sentences_tmps.append(sent + [0] * (max_length - len(sent))) - sentences = sentences_tmps - - sentence_masks = [] - for sent in sentences: - 
sentence_masks.append([(1 if s != 0 else 0) for s in sent]) - sentences = np.asarray(sentences, dtype=np.int32) sentences = np.reshape(sentences, (1, 5, -1)) - - sentence_masks = np.asarray(sentence_masks, dtype=np.float32) - sentence_masks = np.reshape(sentence_masks, (1, 5, -1)) + sentence_masks = (sentences > 0).astype(np.float16) return sentences, sentence_masks
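
Note on PATCH 3/3: the rewritten sentence_to_matrix pads in place, but its truncation branch returns indices[:len(indices)], i.e. the whole list, so a caption longer than max_length appears to no longer be clipped (the pre-patch code clipped with indices[:length]); a ragged result can then break the (1, 5, -1) reshape or the later tensor conversion. A padded-and-clipped variant, given here only as a sketch built on the helpers already in this series, would look like:

from models.jsfusion.data_util import clean_str

def sentence_to_matrix(word2idx, sentence, max_length):
    # Keep only words known to the vocabulary, then right-pad with 0
    # (the padding/unknown index) or clip, so the result always has
    # exactly max_length entries.
    indices = [word2idx[w] for w in clean_str(sentence).split() if w in word2idx]
    if len(indices) >= max_length:
        return indices[:max_length]
    return indices + [0] * (max_length - len(indices))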
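
For reference, attention.add_timing_signal_nd builds the Transformer-style sinusoidal position encoding that MCModel registers as its 'signal' buffer and adds onto the per-frame features. A minimal shape check, assuming only that the module above is importable:

import torch
from models.jsfusion.attention import add_timing_signal_nd

# channels // 2 sine timescales and channels // 2 cosine timescales are
# concatenated along the channel axis, with a leading broadcast dimension.
sig = add_timing_signal_nd(num_frames=40, video_channels=2048)
assert sig.shape == (1, 40, 2048)
assert float(sig[0, 0, 0]) == 0.0  # sin(0) for the first frame's first sine channel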
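
End to end, config/jsfusion-whole.json asks the RnB runner to chain JsFusionLoader -> ResNetRunner -> MCModelRunner on GPU 0. The stages can also be exercised by hand as a quick smoke test. The following is a hypothetical sketch, not part of the port: it assumes LSMDC_PATH is set with the mp4s/, texts/ and hkls/ layout used above (including the hard-coded Juno warm-up clips and a texts/<clip>.txt file with the five candidate captions), that nvvl and a CUDA device are available, and that the runners can be instantiated outside the RnB framework; the first slot of the loader input is unused by JsFusionLoader, so None is passed there.

import os
import torch
from models.jsfusion.model import JsFusionLoader, ResNetRunner, MCModelRunner

device = torch.device('cuda:0')
loader = JsFusionLoader(device)   # NVVL frame loader: 40 frames of 224x224
resnet = ResNetRunner(device)     # pooled ResNet-152 features per frame
mc = MCModelRunner(device)        # JSFusion multiple-choice head

video = os.path.join(os.environ['LSMDC_PATH'],
                     'mp4s/1004_Juno_00.00.32.849-00.00.35.458.mp4')

with torch.no_grad():
    frames = loader((None, video))   # ((40, 3, 224, 224) frames, filename)
    feats = resnet(frames)           # ((40, 2048) ResNet features, filename)
    (winners,), _ = mc(feats)        # index of the predicted caption among the 5

print('predicted caption index:', int(winners.item()))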