-
Notifications
You must be signed in to change notification settings - Fork 250
/
data.py
36 lines (28 loc) · 861 Bytes
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
from collections import Counter
def read_data(fname, count, word2idx):
    """Read a whitespace-tokenised text file and encode it as word indices.

    Both ``count`` and ``word2idx`` are updated in place, so they can be
    passed to successive calls to accumulate state over several files.

    Args:
        fname: Path of the text file to read.
        count: List of ``[word, frequency]`` entries. If empty it is seeded
            with ``['<eos>', 0]``. The frequency of its first entry is
            increased by the number of lines read, and the
            ``(word, frequency)`` pairs of this file are appended in
            ``Counter.most_common()`` order.
        word2idx: Mapping from word to integer index. If empty, ``'<eos>'``
            is assigned index 0. Every word in ``count`` that is not yet
            present is assigned the next free index (the current size of
            the mapping).

    Returns:
        A flat list of word indices for the whole file, with the index of
        ``'<eos>'`` appended after the words of each line.

    Raises:
        FileNotFoundError: If ``fname`` is not an existing regular file.
    """
    if not os.path.isfile(fname):
        # A bare string cannot be raised (that is a TypeError which hides
        # the message), so raise a real exception carrying the same text.
        raise FileNotFoundError("[!] Data %s not found" % fname)

    with open(fname) as f:
        lines = f.readlines()

    # Tokenise every line once; reused for both counting and encoding.
    tokenized = [line.split() for line in lines]
    words = [word for tokens in tokenized for word in tokens]

    if not count:
        count.append(['<eos>', 0])
    # One end-of-sentence marker is emitted per line read.
    count[0][1] += len(lines)
    count.extend(Counter(words).most_common())

    if not word2idx:
        word2idx['<eos>'] = 0
    for word, _ in count:
        if word not in word2idx:
            word2idx[word] = len(word2idx)

    eos_index = word2idx['<eos>']
    data = []
    for tokens in tokenized:
        data.extend(word2idx[word] for word in tokens)
        data.append(eos_index)

    print("Read %s words from %s" % (len(data), fname))
    return data