From 0705cc2215cfdc92adc55df3d47afb40308af98a Mon Sep 17 00:00:00 2001 From: Henry Kironde Date: Mon, 24 Aug 2020 07:51:08 -0400 Subject: [PATCH] Update and clean kaggle support pr Co-authored-by: dumbmachine --- Dockerfile | 1 + requirements.txt | 1 + retriever/lib/defaults.py | 1 + retriever/lib/engine.py | 121 +++++++++++++++++++++++++++++--------- test/test_retriever.py | 28 ++++++++- 5 files changed, 122 insertions(+), 30 deletions(-) diff --git a/Dockerfile b/Dockerfile index f179936d8..2de7ce9bc 100644 --- a/Dockerfile +++ b/Dockerfile @@ -39,6 +39,7 @@ RUN pip install pylint RUN pip install flake8 -U RUN pip install h5py RUN pip install Pillow +RUN pip install kaggle # Install Postgis after Python is setup RUN apt-get install -y --force-yes postgis diff --git a/requirements.txt b/requirements.txt index 969006b83..e1ebb27fc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,7 @@ future xlrd>=0.7 argcomplete +kaggle PyMySQL>=0.4 psycopg2-binary numpydoc diff --git a/retriever/lib/defaults.py b/retriever/lib/defaults.py index 44ea1b47e..927b0f814 100644 --- a/retriever/lib/defaults.py +++ b/retriever/lib/defaults.py @@ -13,6 +13,7 @@ RETRIEVER_REPOSITORY = RETRIEVER_MASTER_BRANCH ENCODING = 'utf-8' HOME_DIR = os.path.expanduser('~/.retriever/') +KAGGLE_TOKEN_PATH = os.path.expanduser('~/.kaggle/kaggle.json') RETRIEVER_DIR = 'retriever' if os.path.exists(os.path.join(HOME_DIR, 'retriever_path.txt')): with open(os.path.join(HOME_DIR, 'retriever_path.txt'), 'r') as f: diff --git a/retriever/lib/engine.py b/retriever/lib/engine.py index 1cfcf65a3..116fc435d 100644 --- a/retriever/lib/engine.py +++ b/retriever/lib/engine.py @@ -18,7 +18,7 @@ from tqdm import tqdm from retriever.lib.cleanup import no_cleanup -from retriever.lib.defaults import DATA_DIR, DATA_SEARCH_PATHS, DATA_WRITE_PATH, ENCODING +from retriever.lib.defaults import DATA_DIR, DATA_SEARCH_PATHS, DATA_WRITE_PATH, ENCODING, KAGGLE_TOKEN_PATH from retriever.lib.tools import ( open_fr, 
open_fw, @@ -283,8 +283,8 @@ def auto_get_datatypes(self, pk, source, columns): if column_types[i][0] == 'double': try: val = float(val) - if "e" in str(val) or \ - ("." in str(val) and len(str(val).split(".")[1]) > 10): + if "e" in str(val) or ("." in str(val) and len( + str(val).split(".")[1]) > 10): column_types[i] = ("decimal", "50,30") except Exception as _: column_types[i] = ('char', max_lengths[i]) @@ -509,6 +509,62 @@ def download_file(self, url, filename): progbar.close() return True + def download_from_kaggle( + self, + data_source, + dataset_name, + archive_dir, + archive_full_path, + ): + """Download files from Kaggle into the raw data directory""" + kaggle_token = os.path.isfile(KAGGLE_TOKEN_PATH) + kaggle_username = os.getenv('KAGGLE_USERNAME', "").strip() + kaggle_key = os.getenv('KAGGLE_KEY', "").strip() + + if kaggle_token or (kaggle_username and kaggle_key): + from kaggle.api.kaggle_api_extended import KaggleApi + from kaggle.rest import ApiException + else: + print(f"Could not find kaggle.json. Make sure it's located at " + f"{KAGGLE_TOKEN_PATH}. Or available in the environment variables. 
" + f"For more information " + f"checkout https://github.com/Kaggle/kaggle-api#api-credentials") + return + + api = KaggleApi() + api.authenticate() + + if data_source == "dataset": + archive_full_path = archive_full_path + ".zip" + try: + api.dataset_download_files(dataset=dataset_name, + path=archive_dir, + quiet=False, + force=True) + file_names = self.extract_zip(archive_full_path, archive_dir) + except ApiException: + print(f"The dataset '{dataset_name}' isn't currently available " + f"in the Retriever.\nRun 'retriever ls' to see a " + f"list of currently available datasets.") + return [] + + else: + archive_full_path = archive_full_path.replace("kaggle:competition:", + "") + ".zip" + try: + api.competition_download_files(competition=dataset_name, + path=archive_dir, + quiet=False, + force=True) + file_names = self.extract_zip(archive_full_path, archive_dir) + except ApiException: + print(f"The dataset '{dataset_name}' isn't currently available " + f"in the Retriever.\nRun 'retriever ls' to see a " + f"list of currently available datasets.") + return [] + + return file_names + def download_files_from_archive( self, url, @@ -534,34 +590,41 @@ def download_files_from_archive( if not os.path.exists(archive_dir): os.makedirs(archive_dir) - if not file_names: - self.download_file(url, archive_name) - if archive_type in ('tar', 'tar.gz'): - file_names = self.extract_tar(archive_full_path, archive_dir, - archive_type) - elif archive_type == 'zip': - file_names = self.extract_zip(archive_full_path, archive_dir) - elif archive_type == 'gz': - file_names = self.extract_gz(archive_full_path, archive_dir) - return file_names + if hasattr(self.script.__dict__, "kaggle"): + file_names = self.download_from_kaggle(data_source=self.script.data_source, + dataset_name=url, + archive_dir=archive_dir, + archive_full_path=archive_full_path) - archive_downloaded = bool(self.data_path) - for file_name in file_names: - archive_full_path = self.format_filename(archive_name) - if not 
self.find_file(os.path.join(archive_dir, file_name)): - # if no local copy, download the data - self.create_raw_data_dir() - if not archive_downloaded: - self.download_file(url, archive_name) - archive_downloaded = True - if archive_type == 'zip': - self.extract_zip(archive_full_path, archive_dir, file_name) + else: + if not file_names: + self.download_file(url, archive_name) + if archive_type in ('tar', 'tar.gz'): + file_names = self.extract_tar(archive_full_path, archive_dir, + archive_type) + elif archive_type == 'zip': + file_names = self.extract_zip(archive_full_path, archive_dir) elif archive_type == 'gz': - self.extract_gz(archive_full_path, archive_dir, file_name) - elif archive_type in ('tar', 'tar.gz'): - self.extract_tar(archive_full_path, archive_dir, archive_type, - file_name) - return file_names + file_names = self.extract_gz(archive_full_path, archive_dir) + return file_names + + archive_downloaded = bool(self.data_path) + for file_name in file_names: + archive_full_path = self.format_filename(archive_name) + if not self.find_file(os.path.join(archive_dir, file_name)): + # if no local copy, download the data + self.create_raw_data_dir() + if not archive_downloaded: + self.download_file(url, archive_name) + archive_downloaded = True + if archive_type == 'zip': + self.extract_zip(archive_full_path, archive_dir, file_name) + elif archive_type == 'gz': + self.extract_gz(archive_full_path, archive_dir, file_name) + elif archive_type in ('tar', 'tar.gz'): + self.extract_tar(archive_full_path, archive_dir, archive_type, + file_name) + return file_names def drop_statement(self, object_type, object_name): """Return drop table or database SQL statement.""" diff --git a/test/test_retriever.py b/test/test_retriever.py index b0139feaf..3333f064e 100644 --- a/test/test_retriever.py +++ b/test/test_retriever.py @@ -26,7 +26,7 @@ from retriever.lib.engine_tools import create_file from retriever.lib.engine_tools import file_2list from retriever.lib.datapackage import 
clean_input, is_empty -from retriever.lib.defaults import HOME_DIR, RETRIEVER_DATASETS, RETRIEVER_REPOSITORY +from retriever.lib.defaults import HOME_DIR, RETRIEVER_DATASETS, RETRIEVER_REPOSITORY, KAGGLE_TOKEN_PATH # Create simple engine fixture test_engine = Engine() @@ -80,6 +80,12 @@ tar_gz_url = os.path.normpath(achive_url.format(file_path='sample_tar.tar.gz')) gz_url = os.path.normpath(achive_url.format(file_path='sample.gz')) +kaggle_datasets = [ + # test_name, data_source, dataset_identifier, repath, expected + ("kaggle_competition", "competition", "titanic", "titanic", ["gender_submission.csv", "test.csv", "train.csv"]), + ("kaggle_dataset", "dataset", "uciml/iris", "iris", ['Iris.csv', 'database.sqlite']), + ("kaggle_unknown", "competition", "non_existent_dataset", "non_existent_dataset", []), +] def setup_module(): """"Automatically sets up the environment before the module runs. @@ -244,6 +250,26 @@ def test_drop_statement(): 'TABLE', 'tablename') == "DROP TABLE IF EXISTS tablename" +@pytest.mark.parametrize("test_name, data_source, dataset_identifier, repath, expected", kaggle_datasets) +def test_download_kaggle_dataset(test_name, data_source, dataset_identifier, repath, expected): + """Test the downloading of dataset from kaggle.""" + setup_functions() + files = test_engine.download_from_kaggle( + data_source=data_source, + dataset_name=dataset_identifier, + archive_dir=raw_dir_files, + archive_full_path=os.path.join(raw_dir_files, repath) + ) + + kaggle_token = os.path.isfile(KAGGLE_TOKEN_PATH) + kaggle_username = os.getenv('KAGGLE_USERNAME', "").strip() + kaggle_key = os.getenv('KAGGLE_KEY', "").strip() + if kaggle_token or (kaggle_username and kaggle_key): + assert files == expected + else: + assert files is None + + def test_download_archive_gz_known(): """Download and extract known files