forked from EdinburghNLP/nematus
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Make corpus shuffling less memory hungry
- Loading branch information
1 parent
a3c2740
commit 736f123
Showing
2 changed files
with
91 additions
and
32 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,50 +1,108 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import math | ||
import os | ||
import sys | ||
import random | ||
|
||
import sys | ||
import tempfile | ||
from subprocess import call | ||
|
||
|
||
# TODO Make CHUNK_SIZE user configurable?
# Number of lines held in memory at a time while shuffling: larger values use
# more memory but create fewer on-disk chunk files.
CHUNK_SIZE = 10000000  # Number of lines.
def jointly_shuffle_files(files, temporary=False):
    """Randomly shuffle the given files, applying the same permutation to each.

    Since the same permutation is applied to all input files, they must
    contain the same number of input lines.

    If 'temporary' is True then the shuffled files are written to temporary
    files. Otherwise, the shuffled files are written to files with the same
    paths as the originals, but with the added suffix '.shuf'.

    In addition to shuffling the files, any leading or trailing whitespace is
    removed from each line.

    In order to handle large files, the input files are not read into memory
    in full, but instead are read in chunks of size CHUNK_SIZE.

    Args:
        files: a list of strings specifying the paths of the input files.
        temporary: a Boolean (see description above).

    Returns:
        A list containing a file object for each shuffled file, in the same
        order as the input files. Each file object is open and positioned at
        the start of the file.
    """
    # Determine the number of lines (should be the same for all files).
    # Use a context manager so the handle is closed promptly, and read as
    # UTF-8 for consistency with the writes performed by _sort_file.
    with open(files[0], encoding="UTF-8") as f:
        total_lines = sum(1 for _ in f)

    # Randomly permute the list of line numbers.
    perm = list(range(total_lines))
    random.shuffle(perm)

    # Convert the list of line numbers to a list of chunk indices and offsets.
    ordering = [(i // CHUNK_SIZE, i % CHUNK_SIZE) for i in perm]

    # Sort each file according to the generated ordering.
    return [_sort_file(path, ordering, temporary) for path in files]
|
||
|
||
def _sort_file(path, ordering, temporary): | ||
|
||
# Open a temporary file for each chunk. | ||
|
||
num_chunks = math.ceil(len(ordering) / CHUNK_SIZE) | ||
dirname, filename = os.path.split(os.path.realpath(path)) | ||
chunk_files = [tempfile.TemporaryFile(prefix=filename+'.chunk'+str(i), | ||
dir=dirname, mode='w+', | ||
encoding="UTF-8") | ||
for i in range(num_chunks)] | ||
|
||
# Read one chunk at a time from path and write the lines to the temporary | ||
# files in the order specified by ordering. | ||
|
||
def _write_chunk_in_order(chunk, chunk_num, out_file): | ||
for i, j in ordering: | ||
if i == chunk_num: | ||
out_file.write(chunk[j] + '\n') | ||
|
||
chunk = [] | ||
chunk_num = 0 | ||
for i, line in enumerate(open(path)): | ||
if i > 0 and (i % CHUNK_SIZE) == 0: | ||
_write_chunk_in_order(chunk, chunk_num, chunk_files[chunk_num]) | ||
chunk = [] | ||
chunk_num += 1 | ||
chunk.append(line.strip()) | ||
if chunk: | ||
_write_chunk_in_order(chunk, chunk_num, chunk_files[chunk_num]) | ||
|
||
# Open the output file. | ||
if temporary: | ||
[ff.seek(0) for ff in fds] | ||
out_file = tempfile.TemporaryFile(prefix=filename+'.shuf', dir=dirname, | ||
mode='w+', encoding='UTF-8') | ||
else: | ||
[ff.close() for ff in fds] | ||
out_file = open(path+'.shuf', mode='w', encoding='UTF-8') | ||
|
||
return fds | ||
# Seek to the start of the chunk files. | ||
for chunk_file in chunk_files: | ||
chunk_file.seek(0) | ||
|
||
if __name__ == '__main__': | ||
main(sys.argv[1:]) | ||
# Write the output. | ||
for i, _ in ordering: | ||
line = chunk_files[i].readline() | ||
out_file.write(line) | ||
|
||
|
||
# Seek to the start so that the file object is ready for reading. | ||
out_file.seek(0) | ||
|
||
return out_file | ||
|
||
|
||
# Command-line entry point: shuffle the files given as arguments, writing the
# results alongside the originals with a '.shuf' suffix.
if __name__ == '__main__':
    jointly_shuffle_files(sys.argv[1:])
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters