diff --git a/data/shuffle.py b/data/shuffle.py
old mode 100644
new mode 100755
index ce417fdc..5661a488
--- a/data/shuffle.py
+++ b/data/shuffle.py
@@ -1,50 +1,108 @@
+#!/usr/bin/env python3
+
+import math
 import os
-import sys
 import random
-
+import sys
 import tempfile
-from subprocess import call
 
-def main(files, temporary=False):
+# TODO Make CHUNK_SIZE user configurable?
+CHUNK_SIZE = 10000000  # Number of lines.
 
-    fds = [open(ff, encoding="UTF-8") for ff in files]
+def jointly_shuffle_files(files, temporary=False):
+    """Randomly shuffle the given files, applying the same permutation to each.
 
-    lines = []
-    for l in fds[0]:
-        line = [l.strip()] + [ff.readline().strip() for ff in fds[1:]]
-        lines.append(line)
+    Since the same permutation is applied to all input files, they must
+    contain the same number of input lines.
 
-    [ff.close() for ff in fds]
+    If 'temporary' is True then the shuffled files are written to temporary
+    files. Otherwise, the shuffled files are written to files with the same
+    paths as the originals, but with the added suffix '.shuf'.
 
-    random.shuffle(lines)
+    In addition to shuffling the files, any leading or trailing whitespace is
+    removed from each line.
 
-    if temporary:
-        fds = []
-        for ff in files:
-            path, filename = os.path.split(os.path.realpath(ff))
-            fd = tempfile.TemporaryFile(prefix=filename+'.shuf',
-                                        dir=path,
-                                        mode='w+',
-                                        encoding="UTF-8")
-            fds.append(fd)
-    else:
-        fds = [open(ff+'.shuf', mode='w', encoding="UTF-8") for ff in files]
+    In order to handle large files, the input files are not read into memory
+    in full, but instead are read in chunks of size CHUNK_SIZE.
+
+    Args:
+        files: a list of strings specifying the paths of the input files.
+        temporary: a Boolean (see description above).
+
+    Returns:
+        A list containing a file object for each shuffled file, in the same
+        order as the input files. Each file object is open and positioned at
+        the start of the file.
+    """
+
+    # Determine the number of lines (should be the same for all files).
+    total_lines = 0
+    for _ in open(files[0]):
+        total_lines += 1
+
+    # Randomly permute the list of line numbers.
+    perm = list(range(total_lines))
+    random.shuffle(perm)
+
+    # Convert the list of line numbers to a list of chunk indices and offsets.
+    ordering = [(i // CHUNK_SIZE, i % CHUNK_SIZE) for i in perm]
 
-    for l in lines:
-        for ii, fd in enumerate(fds):
-            print(l[ii], file=fd)
+    # Sort each file according to the generated ordering.
+    return [_sort_file(path, ordering, temporary) for path in files]
+
+def _sort_file(path, ordering, temporary):
+
+    # Open a temporary file for each chunk.
+
+    num_chunks = math.ceil(len(ordering) / CHUNK_SIZE)
+    dirname, filename = os.path.split(os.path.realpath(path))
+    chunk_files = [tempfile.TemporaryFile(prefix=filename+'.chunk'+str(i),
+                                          dir=dirname, mode='w+',
+                                          encoding="UTF-8")
+                   for i in range(num_chunks)]
+
+    # Read one chunk at a time from path and write the lines to the temporary
+    # files in the order specified by ordering.
+
+    def _write_chunk_in_order(chunk, chunk_num, out_file):
+        for i, j in ordering:
+            if i == chunk_num:
+                out_file.write(chunk[j] + '\n')
+
+    chunk = []
+    chunk_num = 0
+    for i, line in enumerate(open(path)):
+        if i > 0 and (i % CHUNK_SIZE) == 0:
+            _write_chunk_in_order(chunk, chunk_num, chunk_files[chunk_num])
+            chunk = []
+            chunk_num += 1
+        chunk.append(line.strip())
+    if chunk:
+        _write_chunk_in_order(chunk, chunk_num, chunk_files[chunk_num])
+
+    # Open the output file.
     if temporary:
-        [ff.seek(0) for ff in fds]
+        out_file = tempfile.TemporaryFile(prefix=filename+'.shuf', dir=dirname,
+                                          mode='w+', encoding='UTF-8')
     else:
-        [ff.close() for ff in fds]
+        out_file = open(path+'.shuf', mode='w', encoding='UTF-8')
 
-    return fds
+    # Seek to the start of the chunk files.
+    for chunk_file in chunk_files:
+        chunk_file.seek(0)
 
-if __name__ == '__main__':
-    main(sys.argv[1:])
+    # Write the output.
+    for i, _ in ordering:
+        line = chunk_files[i].readline()
+        out_file.write(line)
 
-
+    # Seek to the start so that the file object is ready for reading.
+    out_file.seek(0)
+    return out_file
+
+if __name__ == '__main__':
+    jointly_shuffle_files(sys.argv[1:])
diff --git a/nematus/data_iterator.py b/nematus/data_iterator.py
index 01efa871..5814e4f6 100644
--- a/nematus/data_iterator.py
+++ b/nematus/data_iterator.py
@@ -61,7 +61,8 @@ def __init__(self, source, target,
         elif shuffle_each_epoch:
             self.source_orig = source
             self.target_orig = target
-            self.source, self.target = shuffle.main([self.source_orig, self.target_orig], temporary=True)
+            self.source, self.target = shuffle.jointly_shuffle_files(
+                [self.source_orig, self.target_orig], temporary=True)
         else:
             self.source = fopen(source, 'r')
             self.target = fopen(target, 'r')
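
Note (not part of the patch): a minimal usage sketch of the renamed API, assuming data/shuffle.py is importable as 'shuffle' (as it is from nematus/data_iterator.py); the corpus file names below are hypothetical.

    import shuffle

    # Shuffle two parallel files with the same random permutation. With
    # temporary=True the originals are left untouched and the shuffled copies
    # are returned as open temporary file objects, positioned at the start.
    src, trg = shuffle.jointly_shuffle_files(
        ['corpus.src', 'corpus.trg'], temporary=True)

    # The same permutation was applied to both files, so line i of src still
    # corresponds to line i of trg.
    for src_line, trg_line in zip(src, trg):
        print(src_line.rstrip('\n'), trg_line.rstrip('\n'))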