Merge pull request EdinburghNLP#107 from zippotju/patch
Solve the limitation of MRT
pjwilliams authored Oct 25, 2019
2 parents b92897b + 4c0bb5e commit 637d272
Showing 5 changed files with 328 additions and 340 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -2,3 +2,5 @@
 build
 dist
 nmt.egg-info
+.idea
+.DS_Store
118 changes: 70 additions & 48 deletions nematus/model_updater.py
@@ -54,12 +54,12 @@ def __init__(self, config, num_gpus, replicas, optimizer, global_step,
                 beam_size=config.samplesN)
         else:
             assert config.sample_way == 'randomly_sample'
-            # TODO Set beam_size to config.samplesN instead of using
-            # np.repeat to expand input in full_sampler()?
+            # Set beam_size to config.samplesN instead of using
+            # np.repeat to expand input in full_sampler()
             self._mrt_sampler = RandomSampler(
                 models=[replicas[0]],
                 configs=[config],
-                beam_size=1)
+                beam_size=config.samplesN)
 
     def update(self, session, x, x_mask, y, y_mask, num_to_target,
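The hunk above drops the repeat-then-sample workaround: instead of tiling the input samplesN times and drawing one sample per copy with beam_size=1, the RandomSampler is now asked for config.samplesN candidates per source sentence directly. A toy numpy sketch of the difference (illustrative shapes only, not the Nematus sampler API):

    import numpy as np

    samplesN = 3
    src = np.array([10, 11])                 # two toy source sentences

    # old behaviour: tile the input first, then draw one sample per copy
    expanded = np.repeat(src, samplesN)      # [10 10 10 11 11 11]
    assert expanded.shape[0] == src.shape[0] * samplesN

    # new behaviour: pass src unchanged and let beam_size=samplesN produce
    # samplesN candidates per sentence, so the np.repeat expansion disappears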
@@ -88,6 +88,7 @@ def update(self, session, x, x_mask, y, y_mask, num_to_target,
         # number of replicas, since each replica has to receive some input (the
         # dummy sub-batches will have a weight of zero).
 
+        index = None
         if self._config.loss_function == 'MRT':
             # Generate candidate sentences by sampling based on the source sentences in each minibatch;
             # the outputs are 'samplesN' times larger than the inputs
@@ -104,36 +105,18 @@ def update(self, session, x, x_mask, y, y_mask, num_to_target,
             y = numpy.array(y)
             y_mask = numpy.array(y_mask)
 
-            # limitation (could be solved later):
-            # x, x_mask, y, y_mask would be split into sub-batches. However, to make sure the subspace distribution
-            # of each source sentence is normalised properly, the candidate sentences of each source sentence must
-            # all fall into the same sub-batch. Therefore:
-            # 1. Under the beam search sampling strategy, the number of sampled candidates is exactly equal to the
-            #    configured 'samplesN', so it suffices to make the sub-batch size an integral multiple of 'samplesN'.
-            # 2. Under the random sampling strategy, to keep efficiency high, duplicate candidates are deleted
-            #    without filling up to the full 'samplesN' candidates. Hence, the minibatch cannot be split.
-
-            if self._config.sample_way == 'beam_search':
-                if len(self._replicas) > 1:
-                    assert self._config.max_sentences_per_device >= self._config.samplesN
-                    assert self._config.max_sentences_per_device % self._config.samplesN == 0
-            else:
-                assert self._config.max_sentences_per_device == 0
-                assert self._config.gradient_aggregation_steps == 1
-                assert len(self._replicas) == 1
 
         if (self._config.max_sentences_per_device != 0
                 or self._config.max_tokens_per_device != 0):
             start_points = self._split_minibatch_for_device_size(
                 x_mask, y_mask, self._config.max_sentences_per_device,
-                self._config.max_tokens_per_device)
+                self._config.max_tokens_per_device, index)
         else:
             n = len(self._replicas) * self._config.gradient_aggregation_steps
             start_points = self._split_minibatch_into_n(x_mask, y_mask, n)
 
         if self._config.loss_function == 'MRT':
-            split_x, split_x_mask, split_y, split_y_mask, split_score, weights = \
-                self._split_and_pad_minibatch_mrt(x, x_mask, y, y_mask, score, start_points)
+            split_x, split_x_mask, split_y, split_y_mask, split_score, weights, split_index = \
+                self._split_and_pad_minibatch_mrt(x, x_mask, y, y_mask, score, start_points, index)
         else:
             split_x, split_x_mask, split_y, split_y_mask, weights = \
                 self._split_and_pad_minibatch(x, x_mask, y, y_mask, start_points)
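Judging from how index is consumed further down (the s_index lookup tables and the len(t)-1 weights), index[0] appears to hold the offset at which each source sentence's group of sampled candidates starts in the expanded minibatch, plus the total batch size as an end sentinel. A hedged illustration of that assumed layout:

    # Assumed layout, not taken verbatim from the sampler: three source
    # sentences whose random sampling kept 4, 2 and 3 distinct candidates.
    index = [[0, 4, 6, 9]]   # group start offsets + end sentinel (batch of 9)

Any valid split point must then be one of these offsets, so that every candidate of a source sentence lands in the same sub-batch and its candidate distribution can be normalised there; this is exactly the limitation the removed comment block described.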
@@ -168,9 +151,8 @@ def update(self, session, x, x_mask, y, y_mask, num_to_target,
                 if self._config.loss_function == 'MRT':
                     # convert the evaluation score of each candidate into a tensor for the subsequent expected-risk calculation
                     feed_dict[self._replicas[j].inputs.scores] = split_score[i + j]
-                    if self._config.sample_way == 'randomly_sample':
-                        # also convey the starting point of each source sentence's candidates to the later calculation
-                        feed_dict[self._replicas[j].inputs.index] = index[i + j] #index[0]
+                    # also convey the starting point of each source sentence's candidates to the later calculation
+                    feed_dict[self._replicas[j].inputs.index] = split_index[i + j]
                 feed_dict[self._replicas[j].inputs.training] = True
 
             if self._config.print_per_token_pro == False:
@@ -252,7 +234,7 @@ def _split_minibatch_into_n(self, x_mask, y_mask, n):
 
     def _split_minibatch_for_device_size(self, x_mask, y_mask,
                                          max_sents_per_device=0,
-                                         max_tokens_per_device=0):
+                                         max_tokens_per_device=0, index=None):
         """Determines how to split a minibatch into device-sized sub-batches.
 
         Either max_sents_per_device or max_tokens_per_device must be given.
@@ -270,6 +252,8 @@ def _split_minibatch_for_device_size(self, x_mask, y_mask,
 
         assert max_sents_per_device == 0 or max_tokens_per_device == 0
         assert not (max_sents_per_device == 0 and max_tokens_per_device == 0)
+        if index is not None:
+            s_index = dict(zip(index[0], list(range(len(index[0])))))
 
         source_lengths = numpy.sum(x_mask, axis=0)
         target_lengths = numpy.sum(y_mask, axis=0)
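A hypothetical instance of the mapping added above, under the same assumed index layout: s_index turns the boundary list into an O(1) membership test for "is position j the start of a candidate group?".

    index = [[0, 4, 6, 9]]
    s_index = dict(zip(index[0], list(range(len(index[0])))))
    # s_index == {0: 0, 4: 1, 6: 2, 9: 3}
    assert 4 in s_index and 5 not in s_index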
@@ -282,23 +266,52 @@ def _split_minibatch_for_device_size(self, x_mask, y_mask,
             start_points = list(range(0, num_sents, max_sents_per_device))
         else:
             start_points = [0]
-            while True:
-                i = start_points[-1]
-                s_longest = source_lengths[i]
-                t_longest = target_lengths[i]
-                next_start_point = None
-                for j in range(i+1, num_sents):
-                    s_longest = max(s_longest, source_lengths[j])
-                    t_longest = max(t_longest, target_lengths[j])
-                    s_tokens = s_longest * (j-i+1)
-                    t_tokens = t_longest * (j-i+1)
-                    if (s_tokens > max_tokens_per_device
-                            or t_tokens > max_tokens_per_device):
-                        next_start_point = j
-                        break
-                if next_start_point is None:
-                    break
-                start_points.append(next_start_point)
+            if index is None:
+                while True:
+                    i = start_points[-1]
+                    s_longest = source_lengths[i]
+                    t_longest = target_lengths[i]
+                    next_start_point = None
+                    for j in range(i+1, num_sents):
+                        s_longest = max(s_longest, source_lengths[j])
+                        t_longest = max(t_longest, target_lengths[j])
+                        s_tokens = s_longest * (j-i+1)
+                        t_tokens = t_longest * (j-i+1)
+                        if (s_tokens > max_tokens_per_device
+                                or t_tokens > max_tokens_per_device):
+                            next_start_point = j
+                            break
+                    if next_start_point is None:
+                        break
+                    start_points.append(next_start_point)
+            else:
+                # split the minibatch on index points, which mark the start of
+                # each group of MRT candidate translations
+                while True:
+                    i = start_points[-1]
+                    s_longest = source_lengths[i]
+                    t_longest = target_lengths[i]
+                    next_start_point = None
+                    for j in range(i + 1, num_sents):
+                        s_longest = max(s_longest, source_lengths[j])
+                        t_longest = max(t_longest, target_lengths[j])
+                        s_tokens = s_longest * (j - i + 1)
+                        t_tokens = t_longest * (j - i + 1)
+                        if (s_tokens > max_tokens_per_device
+                                or t_tokens > max_tokens_per_device):
+                            if j in s_index:
+                                next_start_point = j
+                                break
+                            else:
+                                while True:
+                                    j -= 1
+                                    if j in s_index:
+                                        break
+                                next_start_point = j
+                                break
+                    if next_start_point is None:
+                        break
+                    start_points.append(next_start_point)
 
         return start_points
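The else branch mirrors the original token-budget loop but snaps any overflowing split point back to the nearest candidate-group boundary. A self-contained sketch of that idea (simplified to a single length per sentence and no masks; like the real code, it assumes each group fits within the token budget):

    def split_on_boundaries(lengths, max_tokens, boundaries):
        s_index = set(boundaries)
        start_points = [0]
        while True:
            i = start_points[-1]
            longest, next_start = lengths[i], None
            for j in range(i + 1, len(lengths)):
                longest = max(longest, lengths[j])
                if longest * (j - i + 1) > max_tokens:
                    while j not in s_index:   # walk back to a group boundary
                        j -= 1
                    next_start = j
                    break
            if next_start is None:
                return start_points
            start_points.append(next_start)

    # nine candidates in groups starting at 0, 4 and 6; budget of 10 tokens
    print(split_on_boundaries([2] * 9, 10, [0, 4, 6, 9]))   # -> [0, 4]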

@@ -376,7 +389,7 @@ def pad(split_a, padding_size):
         return split_x, split_x_mask, split_y, split_y_mask, weights
 
 
-    def _split_and_pad_minibatch_mrt(self, x, x_mask, y, y_mask, score, start_points):
+    def _split_and_pad_minibatch_mrt(self, x, x_mask, y, y_mask, score, start_points, index):
         """Splits a minibatch according to a list of split points (essentially the same as _split_and_pad_minibatch);
         the only difference is that the evaluation score of each sentence is also split accordingly.
@@ -399,6 +412,15 @@ def _split_and_pad_minibatch_mrt(self, x, x_mask, y, y_mask, score, start_points
         # change shape from batch_size to (1, batch_size)
         score = score[numpy.newaxis, :]
 
+        # split the index for the subsequent calculation of the expected risk
+        tmp = []
+        batch_size = x_mask.shape[-1]
+        start_points_new = start_points + [batch_size]
+        s_index = dict(zip(index[0], list(range(len(index[0])))))
+        for i in range(len(start_points_new) - 1):
+            sub_list = index[0][s_index[start_points_new[i]]:s_index[start_points_new[i + 1]] + 1]
+            tmp.append(numpy.array([l - start_points_new[i] for l in sub_list]))
+
         def split_array(a, start_points):
             batch_size = a.shape[-1]
             next_points = start_points[1:] + [batch_size]
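Under the same assumed index layout, the added loop re-expresses each sub-batch's group boundaries as offsets local to that sub-batch. A worked instance with hypothetical numbers:

    import numpy

    index = [[0, 4, 6, 9]]        # assumed global group boundaries
    start_points = [0, 4]         # split points chosen earlier
    batch_size = 9
    start_points_new = start_points + [batch_size]   # [0, 4, 9]
    s_index = dict(zip(index[0], list(range(len(index[0])))))
    tmp = []
    for i in range(len(start_points_new) - 1):
        sub_list = index[0][s_index[start_points_new[i]]:s_index[start_points_new[i + 1]] + 1]
        tmp.append(numpy.array([l - start_points_new[i] for l in sub_list]))
    # tmp == [array([0, 4]), array([0, 2, 5])] -- local boundaries per sub-batch

Note that this requires every split point (and the batch size) to appear in index[0], which the boundary-snapping in _split_minibatch_for_device_size guarantees.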
@@ -427,8 +449,7 @@ def trim_arrays(arrays, new_seq_lens):
         split_y_mask = trim_arrays(split_y_mask, max_lens)
 
         # number of real sentences (before sampling candidates) in each sub-batch
-        weights = [(t.shape[1]/self._config.samplesN)
-                   for t in split_y_mask]
+        weights = [len(t)-1 for t in tmp]
 
         # Pad the split lists with dummy arrays so that the total number of
         # sub-batches is a multiple of the number of replicas.
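Continuing the hypothetical numbers above: each sub-batch's weight becomes its count of real source sentences, i.e. one less than its number of local boundaries (the trailing sentinel starts no group). The old formula divided the sub-batch size by samplesN, which no longer matches once duplicate candidates have been removed under random sampling.

    import numpy

    tmp = [numpy.array([0, 4]), numpy.array([0, 2, 5])]   # from the sketch above
    weights = [len(t) - 1 for t in tmp]                   # -> [1, 2]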
@@ -446,12 +467,13 @@ def pad(split_a, padding_size):
             pad(split_x_mask, padding_size)
             pad(split_y, padding_size)
             pad(split_y_mask, padding_size)
+            pad(tmp, padding_size)
 
             for i in range(padding_size):
                 weights.append(0.0)
                 split_score.append(0.0)
 
-        return split_x, split_x_mask, split_y, split_y_mask, split_score, weights
+        return split_x, split_x_mask, split_y, split_y_mask, split_score, weights, tmp
 
 
 class _ModelUpdateGraph(object):