
Commit

Minor cleanup
yunseong committed Jul 18, 2019
1 parent 6e4d798 commit b4ad5d3
Showing 5 changed files with 224 additions and 33 deletions.
6 changes: 3 additions & 3 deletions config/jsfusion-whole.json
@@ -3,15 +3,15 @@
"pipeline": [
{
"model": "models.jsfusion.model.JsFusionLoader",
"gpus": [0,1,2]
"gpus": [0]
},
{
"model": "models.jsfusion.model.ResNetRunner",
"gpus": [0,1,2,3,4,5]
"gpus": [0]
},
{
"model": "models.jsfusion.model.MCModelRunner",
"gpus": [4,5]
"gpus": [0]
}
]
}
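For context, a minimal sketch of how a pipeline config like jsfusion-whole.json could be parsed, with each stage resolved to its model class and GPU list. The load_pipeline helper below is hypothetical and not part of this commit; only the JSON structure is taken from the file above.

import importlib
import json

def load_pipeline(path):
    # Hypothetical helper: turn the config into (model_class, gpu_ids) pairs in pipeline order.
    with open(path) as f:
        config = json.load(f)
    stages = []
    for step in config["pipeline"]:
        module_name, class_name = step["model"].rsplit(".", 1)
        model_cls = getattr(importlib.import_module(module_name), class_name)
        stages.append((model_cls, step["gpus"]))
    return stages

# After this commit every stage of jsfusion-whole.json is pinned to GPU 0:
# [(JsFusionLoader, [0]), (ResNetRunner, [0]), (MCModelRunner, [0])]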
12 changes: 4 additions & 8 deletions models/jsfusion/attention.py
@@ -8,23 +8,19 @@ def add_timing_signal_nd(num_frames, video_channels):
     shape = [1, num_frames, video_channels]
     num_dims = len(shape) - 2
     channels = shape[-1]
+
+    position = torch.tensor(range(num_frames), dtype=torch.float32)
+    position = torch.unsqueeze(position, dim=1)
+
     num_timescales = channels // (num_dims * 2)
     log_timescale_increment = math.log(MAX_TIMESCALE / MIN_TIMESCALE) / (num_timescales - 1)
     inv_timescales = []
     for i in range(num_timescales):
         inv_timescales.append(1.0 * math.exp(-float(i) * log_timescale_increment))
     dim = 0
     length = shape[dim + 1]
-
-    position = torch.tensor(range(num_frames), dtype=torch.float32)
     inv_timescales = torch.tensor(inv_timescales, dtype=torch.float32)
-
-    position = torch.unsqueeze(position, dim=1)
-
     inv_timescales = torch.unsqueeze(inv_timescales, dim=0)
-
     scaled_time = position.matmul(inv_timescales)
-
     signal = torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)
     signal = torch.unsqueeze(signal, 0)
-
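For reference, a self-contained sketch of the sinusoidal timing signal that add_timing_signal_nd builds above, with shapes annotated. MIN_TIMESCALE and MAX_TIMESCALE are defined outside this hunk, so the values below (the usual Transformer defaults) are assumptions.

import math
import torch

MIN_TIMESCALE = 1.0    # assumed; defined elsewhere in attention.py
MAX_TIMESCALE = 1.0e4  # assumed; defined elsewhere in attention.py

def timing_signal(num_frames=40, video_channels=2048):
    num_timescales = video_channels // 2
    log_inc = math.log(MAX_TIMESCALE / MIN_TIMESCALE) / (num_timescales - 1)
    inv_timescales = torch.tensor(
        [math.exp(-i * log_inc) for i in range(num_timescales)], dtype=torch.float32)
    position = torch.arange(num_frames, dtype=torch.float32).unsqueeze(1)  # (L, 1)
    scaled_time = position @ inv_timescales.unsqueeze(0)                   # (L, C/2)
    signal = torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)      # (L, C)
    return signal.unsqueeze(0)                                             # (1, L, C)

print(timing_signal().shape)  # torch.Size([1, 40, 2048])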
211 changes: 211 additions & 0 deletions models/jsfusion/data_util.py
@@ -0,0 +1,211 @@
"""Utility class used in JSFusion model, copied from the original author's code
https://github.com/yj-yu/lsmdc/blob/master/videocap/datasets/data_util.py
"""
import time
import numpy as np
import re


def clean_str(string, downcase=True):
"""Tokenization/string cleaning for strings.
Taken from https://github.com/yoonkim/CNN_sentence/blob/master/process_data.py
"""
string = re.sub(r"[^A-Za-z0-9(),!?\'\`(_____)]", " ", string)
string = re.sub(r"\'s", " \'s", string)
string = re.sub(r"\'ve", " \'ve", string)
string = re.sub(r"n\'t", " n\'t", string)
string = re.sub(r"\'re", " \'re", string)
string = re.sub(r"\'d", " \'d", string)
string = re.sub(r"\'ll", " \'ll", string)
string = re.sub(r",", " , ", string)
string = re.sub(r"!", " ! ", string)
string = re.sub(r"\(", " \( ", string)
string = re.sub(r"\)", " \) ", string)
string = re.sub(r"\?", " \? ", string)
string = re.sub(r"\s{2,}", " ", string)
return string.strip().lower() if downcase else string.strip()

def recover_word(string):
string = re.sub(r" \'s", "\'s", string)
string = re.sub(r" ,", ",", string)
return string

def clean_blank(blank_sent):
"""Tokenizes and changes _____ to <START>
<START> would be Answer position in FIB work.
"""
clean_sent = clean_str(blank_sent).split()
return ['<START>' if x == '_____' else x for x in clean_sent]


def clean_root(string):
"""Removes unexpected character in root.
"""
return string


def pad_sequences(sequences, pad_token="[PAD]", pad_location="LEFT", max_length=None):
"""Pads all sequences to the same length. The length is defined by the longest sequence.
Returns padded sequences.
"""
if not max_length:
max_length = max(len(x) for x in sequences)

result = []
for i in range(len(sequences)):
sentence = sequences[i]
num_padding = max_length - len(sentence)
if num_padding == 0:
new_sentence = sentence
elif num_padding < 0:
new_sentence = sentence[:num_padding]
elif pad_location == "RIGHT":
new_sentence = sentence + [pad_token] * num_padding
elif pad_location == "LEFT":
new_sentence = [pad_token] * num_padding + sentence
else:
print("Invalid pad_location. Specify LEFT or RIGHT.")
result.append(new_sentence)
return result


def convert_sent_to_index(sentence, word_to_index):
"""Converts sentence consisting of string to indexed sentence.
"""
return [word_to_index[word] if word in word_to_index.keys() else 0 for word in sentence]


def batch_iter(data, batch_size, seed=None, fill=True):
"""Generates a batch iterator for a dataset.
"""
random = np.random.RandomState(seed)
data_length = len(data)
num_batches = int(data_length / batch_size)
if data_length % batch_size != 0:
num_batches += 1

# Shuffle the data at each epoch
shuffle_indices = random.permutation(np.arange(data_length))
for batch_num in range(num_batches):
start_index = batch_num * batch_size
end_index = min((batch_num + 1) * batch_size, data_length)
selected_indices = shuffle_indices[start_index:end_index]
# If we don't have enough data left for a whole batch, fill it randomly
if fill and end_index >= data_length:
num_missing = batch_size - len(selected_indices)
selected_indices = np.concatenate([selected_indices, random.randint(0, data_length, num_missing)])
yield [data[i] for i in selected_indices]


def fsr_iter(fsr_data, batch_size, random_seed=42, fill=True):
"""fsr_data: one of LSMDCData.build_data(), [[video_features], [sentences], [roots]]
return per iter: [[feature]*batch_size, [sentences]*batch_size, [roots]*batch]
Usage:
train_data, val_data, test_data = LSMDCData.build_data()
for features, sentences, roots in fsr_iter(train_data, 20, 10):
feed_dict = {model.video_feature : features,
model.sentences : sentences,
model.roots : roots}
"""

train_iter = batch_iter(list(zip(*fsr_data)), batch_size, fill=fill, seed=random_seed)
return map(lambda batch: zip(*batch), train_iter)


def preprocess_sents(descriptions, word_to_index, max_length):
descriptions = [clean_str(sent).split() for sent in descriptions]
# Add padding on the right to each sentence in order to keep the same lengths.
descriptions = pad_sequences(descriptions, max_length=max_length)
# Convert sentences from a list of string to the list of indices (int)
descriptions = [convert_sent_to_index(sent, word_to_index) for sent in descriptions]

return descriptions
# remove punctuation mark and special chars from root.


def preprocess_roots(roots, word_to_index):
roots = [clean_root(root) for root in roots]
# convert string to int index.
roots = [word_to_index[root] if root in word_to_index.keys() else 0 for root in roots]

return roots


def pad_video(video_feature, dimension, padded_feature=None):
"""Fills pad to video to have same length.
Pad in Left.
video = [pad,..., pad, frm1, frm2, ..., frmN]
"""
if padded_feature is None:
padded_feature = np.zeros(dimension, dtype=np.float32)
max_length = dimension[0]
current_length = video_feature.shape[0]
num_padding = max_length - current_length
if num_padding == 0:
padded_feature[:] = video_feature
elif num_padding < 0:
steps = np.linspace(0, current_length, num=max_length, endpoint=False, dtype=np.int32)
padded_feature[:] = video_feature[steps]
else:
# about 0.7 sec
padded_feature[num_padding:] = video_feature

return padded_feature

def repeat_pad_video(video_feature, dimension):
padded_feature = np.zeros(dimension, dtype= np.float)
max_length = dimension[0]
current_length = video_feature.shape[0]

if current_length == max_length:
padded_feature[:] = video_feature

elif current_length < max_length:
tile_num = int(max_length / current_length)
to_tile = np.ones(len(dimension), dtype=np.int32)
to_tile[0] = tile_num
remainder = max_length % current_length
tiled_vid = np.tile(video_feature, to_tile)
if remainder > 0:
padded_feature[0:remainder] = video_feature[-remainder:]
padded_feature[remainder:] = tiled_vid

else:
steps = np.linspace(0, current_length, num=max_length, endpoint=False, dtype=np.int32)
padded_feature[:] = video_feature[steps]
return padded_feature

def stretch_pad_video(video_feature, dimension):
padded_feature = np.zeros(dimension, dtype= np.float)
max_length = dimension[0]
current_length = video_feature.shape[0]

if current_length == max_length:
padded_feature[:] = video_feature
elif current_length < max_length:
repeat_num = int((max_length-1) / current_length)+1
tiled_vid = np.repeat(video_feature, repeat_num,0)
steps = np.linspace(0, repeat_num*current_length, num=max_length, endpoint=False, dtype=np.int32)
padded_feature[:] = tiled_vid[steps]
else:
steps = np.linspace(0, current_length, num=max_length, endpoint=False, dtype=np.int32)
padded_feature[:] = video_feature[steps]
return padded_feature


def fill_mask(max_length, current_length, zero_location='LEFT'):
num_padding = max_length - current_length
if num_padding <= 0:
mask = np.ones(max_length)
elif zero_location == 'LEFT':
mask = np.ones(max_length)
for i in range(num_padding):
mask[i] = 0
elif zero_location == 'RIGHT':
mask = np.zeros(max_length)
for i in range(current_length):
mask[i] = 1

return mask
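To make the padding and masking conventions above concrete, a small usage sketch with toy inputs, assuming the helpers are importable from models.jsfusion.data_util:

import numpy as np
from models.jsfusion.data_util import clean_str, fill_mask, pad_sequences, pad_video

# A 3-frame clip of 2048-d features, left-padded to the 40 frames the model expects.
features = np.random.rand(3, 2048).astype(np.float32)
padded = pad_video(features, (40, 2048))       # rows 0..36 are zeros, rows 37..39 hold the clip
mask = fill_mask(40, 3, zero_location='LEFT')  # 37 zeros followed by 3 ones

# Sentences are cleaned, tokenized, and left-padded with "[PAD]" to a common length.
tokens = [clean_str(s).split() for s in ["A man runs.", "Hello"]]
print(pad_sequences(tokens))
# [['a', 'man', 'runs'], ['[PAD]', '[PAD]', 'hello']]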
6 changes: 0 additions & 6 deletions models/jsfusion/model.py
@@ -10,8 +10,6 @@
import nvvl
import os

NUM_FRAMES=40

class JsFusionVideoPathIterator(VideoPathIterator):
    def __init__(self):
        super(JsFusionVideoPathIterator, self).__init__()
@@ -60,11 +58,8 @@ def __call__(self, input):
        frames = frames.float()
        frames = frames.permute(0, 2, 1, 3, 4)

        ### TODO Directly apply this transform
        transform = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        ### TODO This logic seems to be simplified
        frames_tmp = []
        for frame in frames:
            frame = torch.squeeze(frame)
@@ -115,7 +110,6 @@ def __init__(self, device, num_frames = 40):
        self.model.eval()

    def input_shape(self):
        # TODO Input shape
        return ((1, 40, 2048),)

    def __call__(self, input):
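One hunk in this file drops the "Directly apply this transform" note next to the Normalize call above. As a side note, a minimal sketch of two equivalent ways to apply that normalization, under an assumed (num_frames, C, H, W) layout; this sketch is not part of the commit.

import torch
from torchvision import transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

frames = torch.rand(40, 3, 224, 224)  # assumed layout: (num_frames, C, H, W)

# Per-frame loop, as in the diff above:
looped = torch.stack([normalize(frame) for frame in frames])

# Broadcast form that normalizes all frames in one shot:
mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
direct = (frames - mean) / std

assert torch.allclose(looped, direct)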
22 changes: 6 additions & 16 deletions models/jsfusion/module.py
@@ -121,9 +121,9 @@ def __init__(self, device, dropout_prob = 0.5, video_channels = 2048, num_frames
        self.final_bn3 = torch.nn.BatchNorm1d(128, eps=0.001, momentum=0.001)
        self.final_fc4 = torch.nn.Linear(128, 1)
        self.final_bn4 = torch.nn.BatchNorm1d(1, eps=0.001, momentum=0.001)

        self.word2idx = hkl.load(os.path.join(os.environ['LSMDC_PATH'], 'hkls/common_word_to_index_py3.hkl'))



    def video_embeddings(self, video, mask):
        # BxLxC
        embedded_feat_tmp = video + self.signal
@@ -144,7 +144,6 @@ def video_embeddings(self, video, mask):
        relu2 = self.relu2(conv2)
        bn2 = self.bn2(relu2)


        conv3 = self.conv3(bn2)
        relu3 = self.relu3(conv3)
        bn3 = self.bn3(relu3)
@@ -185,7 +184,6 @@ def word_embeddings(self, captions, caption_masks):

        # 5BxLxH
        embedded_sentence = seq_embeddings * caption_masks
        print('embedded_sentence', embedded_sentence.size, embedded_sentence.device, embedded_sentence)

        # 5BxLx1024
        outputs, _ = self.lstm(embedded_sentence)
@@ -202,7 +200,6 @@ def word_embeddings(self, captions, caption_masks):
        return rnn_output



    def fusion(self, v, w, mask, caption_masks):
        # 5Bx512xL
        v = v.repeat(5, 1, 1)
@@ -302,20 +299,17 @@ def fusion_next(self, output1, mask, caption_masks):

        cut_mask_indices = [i for i in range(cut_mask.shape[1]) if i % 2 == 1 and i < cut_mask.shape[1] - 1]
        cut_mask_indices = torch.tensor(cut_mask_indices)
        cut_mask_indices = cut_mask_indices.to(device=self.device, non_blocking=True)#cuda(non_blocking=True)
        cut_mask_indices = cut_mask_indices.to(device=self.device, non_blocking=True)

        # cut_mask = torch.tensor([([0]*(max_len - l) + [1]*l) for l in cut_mask_len.cpu().numpy()],
        #                         dtype=torch.float32)
        # cut_mask = torch.tensor([([0]*(max_len - l) + [1]*l) for l in cut_mask_len.cpu().numpy()], dtype=torch.float32)
        cut_mask = torch.index_select(cut_mask, 1, cut_mask_indices)

        cut_caption_masks_indices = [i for i in range(cut_caption_masks.shape[1]) if i % 2 == 1 and i > 1]
        cut_caption_masks_indices = torch.tensor(cut_caption_masks_indices)
        cut_caption_masks_indices = cut_caption_masks_indices.to(device=self.device, non_blocking=True)#cuda(non_blocking=True)

        cut_caption_masks_indices = cut_caption_masks_indices.to(device=self.device, non_blocking=True)

        # cut_caption_masks = torch.tensor([([1]*l + [0]*(max_len - l)) for l in cut_caption_masks_len.cpu().numpy()],
        #                                  dtype=torch.float32)

        # cut_caption_masks = torch.tensor([([1]*l + [0]*(max_len - l)) for l in cut_caption_masks_len.cpu().numpy()], dtype=torch.float32)
        cut_caption_masks = torch.index_select(cut_caption_masks, 1, cut_caption_masks_indices)

        cut_mask_list.append(cut_mask.repeat(5, 1))
@@ -398,7 +392,6 @@ def fusion_next(self, output1, mask, caption_masks):
        return sum_state



    def final(self, fusion_next):
        # 5Bx256
        a = self.final_fc1(fusion_next)
@@ -420,9 +413,6 @@ def final(self, fusion_next):
        a = self.final_bn4(a)

        return torch.reshape(-a, (-1, 5))





    def parse_sentences(self, word2idx, mc, max_length):
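As a self-contained illustration of the torch.index_select pattern used in fusion_next above (keeping every other column of a mask to halve its length); the toy mask below is made up for the example.

import torch

mask = torch.tensor([[0., 0., 0., 1., 1., 1., 1., 1.]])  # toy mask of length 8

# Keep the odd columns except the last one, as fusion_next does for cut_mask.
indices = torch.tensor([i for i in range(mask.shape[1]) if i % 2 == 1 and i < mask.shape[1] - 1])
halved = torch.index_select(mask, 1, indices)
print(halved)  # tensor([[0., 1., 1.]])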
