data_download.py
import os
import tarfile
from glob import glob

import requests
# root_dir = os.path.join('/storage', 'scratch2', 'ell_data')
root_dir = 'data'

# Create the output tree data/original_data/{train,test}; exist_ok makes the
# script safe to re-run even when some of the directories already exist.
os.makedirs(os.path.join(root_dir, 'original_data', 'train'), exist_ok=True)
os.makedirs(os.path.join(root_dir, 'original_data', 'test'), exist_ok=True)
# Google Drive share links for the five dataset tarballs.
files_links = {
    'agnews': 'https://drive.google.com/file/d/0Bz8a_Dbh9QhbUDNpeUdjb0wxRms/view?usp=sharing',
    'amazon': 'https://drive.google.com/file/d/0Bz8a_Dbh9QhbZVhsUnRWRDhETzA/view?usp=sharing',
    'dbpedia': 'https://drive.google.com/file/d/0Bz8a_Dbh9QhbQ2Vic1kxMmZZQ1k/view?usp=sharing',
    'yelp': 'https://drive.google.com/file/d/0Bz8a_Dbh9QhbZlU4dXhHTFhZQU0/view?usp=sharing',
    'yahoo': 'https://drive.google.com/file/d/0Bz8a_Dbh9Qhbd2JNdDBsQUdocVU/view?usp=sharing',
}
# Adapted from this Stack Overflow answer: https://stackoverflow.com/a/39225039
def download_file_from_google_drive(file_id, destination):
    URL = "https://docs.google.com/uc?export=download"
    session = requests.Session()
    response = session.get(URL, params={'id': file_id}, stream=True)
    # Large files trigger Drive's virus-scan warning page; re-request with the
    # confirmation token so we get the actual file.
    token = get_confirm_token(response)
    if token:
        params = {'id': file_id, 'confirm': token}
        response = session.get(URL, params=params, stream=True)
    save_response_content(response, destination)


def get_confirm_token(response):
    # The warning page sets a cookie named download_warning_<...> whose value
    # is the confirmation token.
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value
    return None


def save_response_content(response, destination):
    CHUNK_SIZE = 32768
    with open(destination, 'wb') as f:
        for chunk in response.iter_content(CHUNK_SIZE):
            if chunk:  # filter out keep-alive chunks
                f.write(chunk)
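
# Optional sanity check for the flow above: when the cookie-based confirmation
# fails, Drive can save an HTML warning page where the tarball should be. Gzip
# streams always begin with the magic bytes 0x1f 0x8b, so a two-byte peek is
# enough to tell the cases apart. (A minimal sketch; the looks_like_gzip
# helper is introduced here for illustration, not part of the referenced answer.)
def looks_like_gzip(path):
    with open(path, 'rb') as f:
        return f.read(2) == b'\x1f\x8b'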

# Download the tarballs. The file id is the path segment after /file/d/ in
# each share link.
for name, url in files_links.items():
    output = os.path.join(root_dir, name + '.tar.gz')
    file_id = url.split('/')[5]
    download_file_from_google_drive(file_id, output)
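
# Fail fast if any download is not actually gzip data (a sketch built on the
# optional looks_like_gzip helper above); tarfile.open would otherwise raise a
# less descriptive ReadError later.
for name in files_links:
    path = os.path.join(root_dir, name + '.tar.gz')
    if not looks_like_gzip(path):
        raise RuntimeError(f'{path} does not look like a gzip archive; retry the download')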

# Extract each archive's CSVs. Members are paths like '<dataset>_csv/train.csv';
# the split name becomes the destination folder and the file is renamed to
# '<dataset>.csv' so the five datasets do not collide.
for name in files_links:
    with tarfile.open(os.path.join(root_dir, name + '.tar.gz')) as tf:
        for m in tf.getmembers():
            if m.name.endswith('.csv'):
                folder = m.name.split('.')[0].split('/')[-1]  # 'train' or 'test'
                m.name = name + '.csv'
                print(f'Extracting {m.name} into {folder}')
                tf.extract(m, path=os.path.join(root_dir, 'original_data', folder))
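
# Quick sanity listing: data/original_data/train and .../test should now each
# contain one renamed CSV per dataset (agnews.csv, amazon.csv, ...).
for split in ('train', 'test'):
    split_dir = os.path.join(root_dir, 'original_data', split)
    print(split, sorted(os.listdir(split_dir)))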

# Delete the tarballs now that the CSVs have been extracted.
for t in glob(os.path.join(root_dir, '*.gz')):
    os.remove(t)
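
# Example downstream read, as a sketch: assumes pandas is installed and that
# these CSVs ship without a header row (label in the first column), which is
# how this family of text-classification releases is usually distributed.
#   import pandas as pd
#   train = pd.read_csv(os.path.join(root_dir, 'original_data', 'train', 'agnews.csv'),
#                       header=None)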