Skip to content

Commit

Permalink
Fix error for tabgenie --help and minor refactoring
Browse files Browse the repository at this point in the history
Refactoring towards faster flask cli

flask CLI could be really slow if
each command imports all the libraries also
from the other commands.

Solution: each command should import only its dependencies

This commit tried to refactored in this direction.
But it seems that we need create_app on most of the commands
For these commands it did not speed up the CLI
  • Loading branch information
oplatek committed Jan 25, 2023
1 parent 664258e commit 7b3ed90
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 41 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[tool.black]
line-length = 120
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
extras_require={
"dev": [
"wheel",
"black",
],
"deploy": [
"gunicorn",
Expand Down
56 changes: 54 additions & 2 deletions src/tabgenie/cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,54 @@
#!/usr/bin/env python3
import click
import os
import logging
from flask.cli import FlaskGroup, with_appcontext, pass_script_info

from .main import create_app, export_dataset
logger = logging.getLogger(__name__)


def create_app(**kwargs):
from click import get_current_context
import yaml

ctx = get_current_context(silent=True)
if ctx and hasattr(ctx.obj, "disable_pipelines"):
disable_pipelines = ctx.obj.disable_pipelines
else:
# Production server, e.g., gunincorn
# We don't have access to the current context, so must read kwargs instead.
disable_pipelines = kwargs.get("disable_pipelines", False)

with open("config.yml") as f:
config = yaml.safe_load(f)

# Imports from main slow down flask CLI
# since main have very time-consuming libraries to import
from .main import app, initialize_pipeline, load_prompts

app.config.update(config)
app.config["root_dir"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
app.config["datasets_obj"] = {}
app.config["pipelines_obj"] = {}
app.config["prompts"] = load_prompts()

if app.config.get("pipelines") and not disable_pipelines:
for pipeline_name in app.config["pipelines"].keys():
initialize_pipeline(pipeline_name)
else:
app.config["pipelines"] = {}

# preload
if config["cache_dev_splits"]:
for dataset_name in app.config["datasets"]:
get_dataset(dataset_name, "dev")

if config["debug"] is False:
logging.getLogger("werkzeug").disabled = True

logger.info("Application ready")

return app


@click.group(cls=FlaskGroup, create_app=create_app)
Expand Down Expand Up @@ -31,6 +77,12 @@ def run(script_info, disable_pipelines):
)
@with_appcontext
def export(dataset, split, out_dir, export_format, json_template):
from .main import export_dataset

export_dataset(
dataset_name=dataset, split=split, out_dir=out_dir, export_format=export_format, json_template=json_template
dataset_name=dataset,
split=split,
out_dir=out_dir,
export_format=export_format,
json_template=json_template,
)
39 changes: 0 additions & 39 deletions src/tabgenie/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@
import logging
import linecache

import yaml
import coloredlogs
import pandas as pd
from flask import Flask, render_template, jsonify, request, send_file
from click import get_current_context

from .loaders import DATASET_CLASSES
from .processing.processing import get_pipeline_class_by_name
Expand Down Expand Up @@ -282,40 +280,3 @@ def index():
default_dataset=app.config["default_dataset"],
host_prefix=app.config["host_prefix"],
)


def create_app(**kwargs):
ctx = get_current_context(silent=True)
if ctx:
disable_pipelines = ctx.obj.disable_pipelines
else:
# Production server, e.g., gunincorn
# We don't have access to the current context, so must read kwargs instead.
disable_pipelines = kwargs.get("disable_pipelines", False)

with open("config.yml") as f:
config = yaml.safe_load(f)

app.config.update(config)
app.config["root_dir"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
app.config["datasets_obj"] = {}
app.config["pipelines_obj"] = {}
app.config["prompts"] = load_prompts()

if app.config.get("pipelines") and not disable_pipelines:
for pipeline_name in app.config["pipelines"].keys():
initialize_pipeline(pipeline_name)
else:
app.config["pipelines"] = {}

# preload
if config["cache_dev_splits"]:
for dataset_name in app.config["datasets"]:
get_dataset(dataset_name, "dev")

if config["debug"] is False:
logging.getLogger("werkzeug").disabled = True

logger.info("Application ready")

return app

0 comments on commit 7b3ed90

Please sign in to comment.