Make corpus shuffling less memory hungry
pjwilliams committed Oct 31, 2019
1 parent a3c2740 commit 736f123
Showing 2 changed files with 91 additions and 32 deletions.
120 changes: 89 additions & 31 deletions data/shuffle.py
100644 → 100755
@@ -1,50 +1,108 @@
 #!/usr/bin/env python3
 
+import math
 import os
-import sys
 import random
-
+import sys
 import tempfile
-from subprocess import call
+
+# TODO Make CHUNK_SIZE user configurable?
+CHUNK_SIZE = 10000000  # Number of lines.
 
 
-def main(files, temporary=False):
-
-    fds = [open(ff, encoding="UTF-8") for ff in files]
-
-    lines = []
-    for l in fds[0]:
-        line = [l.strip()] + [ff.readline().strip() for ff in fds[1:]]
-        lines.append(line)
-
-    [ff.close() for ff in fds]
-
-    random.shuffle(lines)
-
-    if temporary:
-        fds = []
-        for ff in files:
-            path, filename = os.path.split(os.path.realpath(ff))
-            fd = tempfile.TemporaryFile(prefix=filename+'.shuf',
-                                        dir=path,
-                                        mode='w+',
-                                        encoding="UTF-8")
-            fds.append(fd)
-    else:
-        fds = [open(ff+'.shuf', mode='w', encoding="UTF-8") for ff in files]
-
-    for l in lines:
-        for ii, fd in enumerate(fds):
-            print(l[ii], file=fd)
-
-    if temporary:
-        [ff.seek(0) for ff in fds]
-    else:
-        [ff.close() for ff in fds]
-
-    return fds
+def jointly_shuffle_files(files, temporary=False):
+    """Randomly shuffle the given files, applying the same permutation to each.
+
+    Since the same permutation is applied to all input files, they must
+    contain the same number of input lines.
+
+    If 'temporary' is True then the shuffled files are written to temporary
+    files. Otherwise, the shuffled files are written to files with the same
+    paths as the originals, but with the added suffix '.shuf'.
+
+    In addition to shuffling the files, any leading or trailing whitespace is
+    removed from each line.
+
+    In order to handle large files, the input files are not read into memory
+    in full, but instead are read in chunks of size CHUNK_SIZE.
+
+    Args:
+        files: a list of strings specifying the paths of the input files.
+        temporary: a Boolean (see description above).
+
+    Returns:
+        A list containing a file object for each shuffled file, in the same
+        order as the input files. Each file object is open and positioned at
+        the start of the file.
+    """
+
+    # Determine the number of lines (should be the same for all files).
+    total_lines = 0
+    for _ in open(files[0]):
+        total_lines += 1
+
+    # Randomly permute the list of line numbers.
+    perm = list(range(total_lines))
+    random.shuffle(perm)
+
+    # Convert the list of line numbers to a list of chunk indices and offsets.
+    ordering = [(i // CHUNK_SIZE, i % CHUNK_SIZE) for i in perm]
+
+    # Sort each file according to the generated ordering.
+    return [_sort_file(path, ordering, temporary) for path in files]
+
+
+def _sort_file(path, ordering, temporary):
+
+    # Open a temporary file for each chunk.
+
+    num_chunks = math.ceil(len(ordering) / CHUNK_SIZE)
+    dirname, filename = os.path.split(os.path.realpath(path))
+    chunk_files = [tempfile.TemporaryFile(prefix=filename+'.chunk'+str(i),
+                                          dir=dirname, mode='w+',
+                                          encoding="UTF-8")
+                   for i in range(num_chunks)]
+
+    # Read one chunk at a time from path and write the lines to the temporary
+    # files in the order specified by ordering.
+
+    def _write_chunk_in_order(chunk, chunk_num, out_file):
+        for i, j in ordering:
+            if i == chunk_num:
+                out_file.write(chunk[j] + '\n')
+
+    chunk = []
+    chunk_num = 0
+    for i, line in enumerate(open(path)):
+        if i > 0 and (i % CHUNK_SIZE) == 0:
+            _write_chunk_in_order(chunk, chunk_num, chunk_files[chunk_num])
+            chunk = []
+            chunk_num += 1
+        chunk.append(line.strip())
+    if chunk:
+        _write_chunk_in_order(chunk, chunk_num, chunk_files[chunk_num])
+
+    # Open the output file.
+    if temporary:
+        out_file = tempfile.TemporaryFile(prefix=filename+'.shuf', dir=dirname,
+                                          mode='w+', encoding='UTF-8')
+    else:
+        out_file = open(path+'.shuf', mode='w', encoding='UTF-8')
+
+    # Seek to the start of the chunk files.
+    for chunk_file in chunk_files:
+        chunk_file.seek(0)
+
+    # Write the output.
+    for i, _ in ordering:
+        line = chunk_files[i].readline()
+        out_file.write(line)
+
+    # Seek to the start so that the file object is ready for reading.
+    out_file.seek(0)
+
+    return out_file
+
 
 if __name__ == '__main__':
-    main(sys.argv[1:])
+    jointly_shuffle_files(sys.argv[1:])
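The heart of the change is the two-pass chunk scheme in _sort_file: line text is held in memory only one chunk at a time, while the permutation itself is kept as just one (chunk, offset) pair per line. The toy sketch below (my own illustration, not code from this commit; CHUNK_SIZE shrunk to 3 so the effect is visible) shows why replaying 'ordering' against per-chunk queues reproduces exactly the shuffled order:

import random

CHUNK_SIZE = 3                            # tiny stand-in for 10,000,000
lines = ['line%d' % n for n in range(7)]  # stands in for a corpus file

perm = list(range(len(lines)))
random.shuffle(perm)
ordering = [(i // CHUNK_SIZE, i % CHUNK_SIZE) for i in perm]

# Pass 1: split the "file" into chunks and, per chunk, emit its lines in the
# order they occur in 'ordering' (what _write_chunk_in_order does per chunk
# file).
num_chunks = (len(lines) + CHUNK_SIZE - 1) // CHUNK_SIZE
chunks = [lines[c * CHUNK_SIZE:(c + 1) * CHUNK_SIZE]
          for c in range(num_chunks)]
chunk_queues = [[chunk[j] for i, j in ordering if i == c]
                for c, chunk in enumerate(chunks)]

# Pass 2: replay 'ordering', taking the next unread line from the named
# chunk; each queue is consumed strictly front to back, exactly like
# readline() on a rewound chunk file.
shuffled = [chunk_queues[i].pop(0) for i, _ in ordering]

assert shuffled == [lines[i] for i in perm]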
3 changes: 2 additions & 1 deletion nematus/data_iterator.py
@@ -61,7 +61,8 @@ def __init__(self, source, target,
         elif shuffle_each_epoch:
             self.source_orig = source
             self.target_orig = target
-            self.source, self.target = shuffle.main([self.source_orig, self.target_orig], temporary=True)
+            self.source, self.target = shuffle.jointly_shuffle_files(
+                [self.source_orig, self.target_orig], temporary=True)
         else:
             self.source = fopen(source, 'r')
             self.target = fopen(target, 'r')
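For context, a minimal standalone sketch of how the new API is consumed (hypothetical corpus paths; assumes data/shuffle.py is importable as 'shuffle', as at the call site above):

import shuffle  # data/shuffle.py; assumed to be on the import path

# Parallel files must contain the same number of lines; the same random
# permutation is applied to both, so sentence pairs stay aligned.
src, tgt = shuffle.jointly_shuffle_files(
    ['corpus.en', 'corpus.de'], temporary=True)

# The returned file objects are open and rewound, so the shuffled,
# whitespace-stripped lines can be read back immediately.
for src_line, tgt_line in zip(src, tgt):
    pass  # e.g. hand the aligned pair to the training pipeline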
