From 7b3ed9054d2827849d306d99ba58b65b94f1fd60 Mon Sep 17 00:00:00 2001 From: Ondrej Platek Date: Wed, 25 Jan 2023 11:32:43 +0100 Subject: [PATCH] Fix error for tabgenie --help and minor refactoring Refactoring towards a faster Flask CLI. The Flask CLI can be really slow if each command also imports all the libraries needed by the other commands. Solution: each command should import only its own dependencies. This commit tries to refactor in this direction. However, it seems that we need create_app for most of the commands. For these commands, it did not speed up the CLI. --- pyproject.toml | 2 ++ setup.py | 1 + src/tabgenie/cli.py | 56 ++++++++++++++++++++++++++++++++++++++++++-- src/tabgenie/main.py | 39 ------------------------------ 4 files changed, 57 insertions(+), 41 deletions(-) create mode 100644 pyproject.toml diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..55ec8d7 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 120 diff --git a/setup.py b/setup.py index dc9d771..88f8353 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ extras_require={ "dev": [ "wheel", + "black", ], "deploy": [ "gunicorn", diff --git a/src/tabgenie/cli.py b/src/tabgenie/cli.py index cdf6bf5..91ab11d 100755 --- a/src/tabgenie/cli.py +++ b/src/tabgenie/cli.py @@ -1,8 +1,54 @@ #!/usr/bin/env python3 import click +import os +import logging from flask.cli import FlaskGroup, with_appcontext, pass_script_info -from .main import create_app, export_dataset +logger = logging.getLogger(__name__) + + +def create_app(**kwargs): + from click import get_current_context + import yaml + + ctx = get_current_context(silent=True) + if ctx and hasattr(ctx.obj, "disable_pipelines"): + disable_pipelines = ctx.obj.disable_pipelines + else: + # Production server, e.g., gunicorn + # We don't have access to the current context, so must read kwargs instead. 
+ disable_pipelines = kwargs.get("disable_pipelines", False) + + with open("config.yml") as f: + config = yaml.safe_load(f) + + # Imports from main slow down flask CLI + # since main has very time-consuming libraries to import + from .main import app, get_dataset, initialize_pipeline, load_prompts + + app.config.update(config) + app.config["root_dir"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir) + app.config["datasets_obj"] = {} + app.config["pipelines_obj"] = {} + app.config["prompts"] = load_prompts() + + if app.config.get("pipelines") and not disable_pipelines: + for pipeline_name in app.config["pipelines"].keys(): + initialize_pipeline(pipeline_name) + else: + app.config["pipelines"] = {} + + # preload + if config["cache_dev_splits"]: + for dataset_name in app.config["datasets"]: + get_dataset(dataset_name, "dev") + + if config["debug"] is False: + logging.getLogger("werkzeug").disabled = True + + logger.info("Application ready") + + return app @click.group(cls=FlaskGroup, create_app=create_app) @@ -31,6 +77,12 @@ def run(script_info, disable_pipelines): ) @with_appcontext def export(dataset, split, out_dir, export_format, json_template): + from .main import export_dataset + export_dataset( - dataset_name=dataset, split=split, out_dir=out_dir, export_format=export_format, json_template=json_template + dataset_name=dataset, + split=split, + out_dir=out_dir, + export_format=export_format, + json_template=json_template, ) diff --git a/src/tabgenie/main.py b/src/tabgenie/main.py index 66a68f5..0681e92 100755 --- a/src/tabgenie/main.py +++ b/src/tabgenie/main.py @@ -6,11 +6,9 @@ import logging import linecache -import yaml import coloredlogs import pandas as pd from flask import Flask, render_template, jsonify, request, send_file -from click import get_current_context from .loaders import DATASET_CLASSES from .processing.processing import get_pipeline_class_by_name @@ -282,40 +280,3 @@ def index(): 
default_dataset=app.config["default_dataset"], host_prefix=app.config["host_prefix"], ) - - -def create_app(**kwargs): - ctx = get_current_context(silent=True) - if ctx: - disable_pipelines = ctx.obj.disable_pipelines - else: - # Production server, e.g., gunincorn - # We don't have access to the current context, so must read kwargs instead. - disable_pipelines = kwargs.get("disable_pipelines", False) - - with open("config.yml") as f: - config = yaml.safe_load(f) - - app.config.update(config) - app.config["root_dir"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir) - app.config["datasets_obj"] = {} - app.config["pipelines_obj"] = {} - app.config["prompts"] = load_prompts() - - if app.config.get("pipelines") and not disable_pipelines: - for pipeline_name in app.config["pipelines"].keys(): - initialize_pipeline(pipeline_name) - else: - app.config["pipelines"] = {} - - # preload - if config["cache_dev_splits"]: - for dataset_name in app.config["datasets"]: - get_dataset(dataset_name, "dev") - - if config["debug"] is False: - logging.getLogger("werkzeug").disabled = True - - logger.info("Application ready") - - return app