diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e5c0f3f..4fdeed2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,11 +1,7 @@
 repos:
-  - repo: https://github.com/astral-sh/ruff-pre-commit
-    # Ruff version.
-    rev: v0.0.280
-    hooks:
-      - id: ruff
-  - repo: https://github.com/psf/black
-    rev: 23.7.0
-    hooks:
-      - id: black
-        language_version: python3
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  rev: v0.1.8
+  hooks:
+  - id: ruff
+    args: [ --fix, --exit-non-zero-on-fix ]
+  - id: ruff-format
\ No newline at end of file
diff --git a/README.md b/README.md
index 21f0f70..9ece4b0 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,46 @@
-# OpenHexa Python SDK
+OpenHEXA Logo
+
+Open-source Data integration platform
+
+Test Suite
+
-The OpenHexa Python SDK is a tool that helps you write code for the OpenHexa platform.
+OpenHEXA Python SDK
+===================
 
-It is particularly useful to write OpenHexa data pipelines, but can also be used in the OpenHexa notebooks environment.
+OpenHEXA is an open-source data integration platform developed by [Bluesquare](https://bluesquarehub.com).
 
-## Quickstart
+Its goal is to facilitate data integration and analysis workflows, in particular in the context of public health
+projects.
 
-### Writing and deploying pipelines
+Please refer to the [OpenHEXA wiki](https://github.com/BLSQ/openhexa/wiki/Home) for more information about OpenHEXA.
+
+This repository contains the code of the OpenHEXA SDK, a library that allows you to write code for the OpenHEXA platform.
+It is particularly useful to write OpenHEXA data pipelines, but can also be used in the OpenHEXA notebooks environment.
+
+The OpenHEXA wiki has a section dedicated to the SDK:
+[Using the OpenHEXA SDK](https://github.com/BLSQ/openhexa/wiki/Using-the-OpenHEXA-SDK).
+
+For more information about the technical aspects of OpenHEXA, you might be interested in the following two wiki pages:
+
+- [Installing OpenHEXA](https://github.com/BLSQ/openhexa/wiki/Installation-instructions)
+- [Technical Overview](https://github.com/BLSQ/openhexa/wiki/Technical-overview)
+
+Requirements
+------------
+
+The OpenHEXA SDK requires Python 3.9 or newer; it is not yet compatible with Python 3.12 or later versions.
+
+If you want to be able to run pipelines in a containerized environment on your machine, you will need
+[Docker](https://www.docker.com/).
+
+Quickstart
+----------
 
 Here's a super minimal example to get you started. First, create a new directory and a virtual environment:
 
@@ -17,12 +51,15 @@
 python -m venv venv
 source venv/bin/activate
 ```
 
-You can then install the OpenHexa SDK:
+You can then install the OpenHEXA SDK:
 
 ```shell
 pip install --upgrade openhexa.sdk
 ```
 
+💡 New OpenHEXA SDK versions are released on a regular basis. Don't forget to update your local installations with
+`pip install --upgrade` from time to time!
+
+Now that the SDK is installed within your virtual environment, you can use the `openhexa` CLI utility to create
 a new pipeline:
 
@@ -30,8 +67,8 @@
 openhexa pipelines init "My awesome pipeline"
 ```
 
-Great! As you can see in the console output, the OpenHexa CLI has created a new directory, which contains the basic
-structure required for an OpenHexa pipeline. You can now `cd` in the new pipeline directory and run the pipeline:
+Great! As you can see in the console output, the OpenHEXA CLI has created a new directory, which contains the basic
+structure required for an OpenHEXA pipeline. You can now `cd` in the new pipeline directory and run the pipeline:
 
 ```shell
 cd my_awesome_pipeline
 python pipeline.py
@@ -41,11 +78,11 @@ python pipeline.py
 
 Congratulations! You have successfully run your first pipeline locally.
 
 If you inspect the actual pipeline code, you will see that it doesn't do a lot of things, but it is still a perfectly
-valid OpenHexa pipeline.
+valid OpenHEXA pipeline.
 
-Let's publish to an actual OpenHexa workspace so that it can run online.
+Let's publish to an actual OpenHEXA workspace so that it can run online.
 
-Using the OpenHexa web interface, within a workspace, navigate to the Pipelines tab and click on "Create".
+Using the OpenHEXA web interface, within a workspace, navigate to the Pipelines tab and click on "Create".
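For reference, the pipeline generated by `openhexa pipelines init` is essentially the skeleton template that appears further down in this diff (`openhexa/cli/skeleton/pipeline.py`), with the placeholders substituted. A sketch with illustrative names:

```python
"""Minimal OpenHEXA pipeline, adapted from the skeleton template in this diff."""

from openhexa.sdk import current_run, pipeline


@pipeline("my-awesome-pipeline", name="My awesome pipeline")
def my_awesome_pipeline():
    """Orchestrate tasks; the tasks themselves do the actual work."""
    count = task_1()
    task_2(count)


@my_awesome_pipeline.task
def task_1():
    """Produce a value for downstream tasks."""
    current_run.log_info("In task 1...")
    return 42


@my_awesome_pipeline.task
def task_2(count):
    """Consume the value produced by task_1."""
    current_run.log_info(f"In task 2... count is {count}")
```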
 Copy the command displayed in the popup in your terminal:
 
@@ -62,20 +99,17 @@
 openhexa pipelines push
 ```
 
 As it is the first time, the CLI will ask you to confirm the creation operation. After confirmation the console will
-output the link to the pipeline screen in the OpenHexa interface.
-
-You can now open the link and run the pipeline using the OpenHexa web interface.
+output the link to the pipeline screen in the OpenHEXA interface.
 
-### Using the SDK in the notebooks environment
+You can now open the link and run the pipeline using the OpenHEXA web interface.
 
-TBC
+Contributing
+------------
 
-## Contributing
+The following sections explain how you can set up a local development environment if you want to participate in the
+development of the SDK.
 
-The following sections explain how you can setup a local development environment if you want to participate to the
-development of the SDK
-
-### Development setup
+### SDK development setup
 
 Install the SDK in editable mode:
 
@@ -85,7 +119,13 @@
 source venv/bin/activate  # Activate the venv
 pip install -e ".[dev]"  # Necessary to be able to run the openhexa CLI
 ```
 
-### Using a local installation of the OpenHexa backend to run pipelines
+### Using a local installation of OpenHEXA to run pipelines
+
+While it is possible to run pipelines locally using only the SDK, if you want to run OpenHEXA in a more realistic
+setting you will need to install the OpenHEXA app and frontend components. Please refer to the
+[installation instructions](https://github.com/BLSQ/openhexa/wiki/Installation-instructions) for more information.
+
+You can then configure the OpenHEXA CLI to connect to your local backend:
 
 ```shell
 openhexa config set_url http://localhost:8000
 ```
 
 Notes: you can monitor the status of your pipelines using http://localhost:8000/
 
 ### Running the tests
 
-Run the tests using pytest:
+You can run the tests using pytest:
 
 ```shell
 pytest
diff --git a/examples/pipelines/logistic_stats/pipeline.py b/examples/pipelines/logistic_stats/pipeline.py
index 513823f..0933a69 100644
--- a/examples/pipelines/logistic_stats/pipeline.py
+++ b/examples/pipelines/logistic_stats/pipeline.py
@@ -1,3 +1,5 @@
+"""Simple module for a sample logistic pipeline."""
+
 import json
 import typing
 from io import BytesIO
@@ -26,6 +28,7 @@
 )
 @parameter("oul", name="Organisation unit level", type=int, default=2)
 def logistic_stats(deg: str, periods: str, oul: int):
+    """Run a basic logistic stats pipeline."""
     dhis2_data = dhis2_download(deg, periods, oul)
     gadm_data = gadm_download()
     worldpop_data = worldpop_download()
@@ -34,7 +37,8 @@
 @logistic_stats.task
-def dhis2_download(data_element_group: str, periods: str, org_unit_level: int) -> typing.Dict[str, typing.Any]:
+def dhis2_download(data_element_group: str, periods: str, org_unit_level: int) -> dict[str, typing.Any]:
+    """Download DHIS2 data."""
     connection = workspace.dhis2_connection("dhis2-play")
     base_url = f"{connection.url}/api"
     session = requests.Session()
@@ -61,6 +65,7 @@
 @logistic_stats.task
 def gadm_download():
+    """Download administrative boundaries data from UCDavis."""
     url = "https://geodata.ucdavis.edu/gadm/gadm4.1/gpkg/gadm41_SLE.gpkg"
 
     r = requests.get(url, timeout=30)
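The elided body of `dhis2_download` above presumably authenticates the `requests` session against the DHIS2 connection before calling the API. A minimal sketch of that pattern — the `username`/`password` connection fields and the queried endpoint are assumptions for illustration, not confirmed by this diff:

```python
"""Sketch: querying DHIS2 through a workspace connection (fields and endpoint are assumptions)."""

import requests

from openhexa.sdk import workspace

# "dhis2-play" is the connection slug used by the examples in this diff
connection = workspace.dhis2_connection("dhis2-play")

session = requests.Session()
session.auth = (connection.username, connection.password)  # assumed connection fields

# Illustrative DHIS2 Web API endpoint
response = session.get(f"{connection.url}/api/dataElementGroups.json", timeout=30)
response.raise_for_status()
print(response.json())
```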
@@ -69,6 +74,7 @@
 @logistic_stats.task
 def worldpop_download():
+    """Download population data from worldpop.org."""
     base_url = "https://data.worldpop.org/"
     url = f"{base_url}GIS/Population/Global_2000_2020_Constrained/2020/maxar_v1/SLE/sle_ppp_2020_UNadj_constrained.tif"
 
     r = requests.get(url)
@@ -77,7 +83,8 @@
 @logistic_stats.task
-def model(dhis2_data: typing.Dict[str, typing.Any], gadm_data, worldpop_data):
+def model(dhis2_data: dict[str, typing.Any], gadm_data, worldpop_data):
+    """Build a basic data model."""
     # Load DHIS2 data
     dhis2_df = pd.DataFrame(dhis2_data["rows"], columns=[h["column"] for h in dhis2_data["headers"]])
     dhis2_df = dhis2_df.rename(columns={"Data": "Data element id", "Organisation unit": "Organisation unit id"})
diff --git a/examples/pipelines/logistic_stats/tests/__init__.py b/examples/pipelines/logistic_stats/tests/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/examples/pipelines/simple_io/pipeline.py b/examples/pipelines/simple_io/pipeline.py
index f847316..8a52760 100644
--- a/examples/pipelines/simple_io/pipeline.py
+++ b/examples/pipelines/simple_io/pipeline.py
@@ -1,3 +1,5 @@
+"""Simple module for a sample IO pipeline."""
+
 import json
 
 import pandas as pd
@@ -9,6 +11,7 @@
 @pipeline("simple-io", name="Simple IO")
 def simple_io():
+    """Run a simple IO pipeline."""
     # Read and write from/to workspace files
     raw_files_data = load_files_data()
     transform_and_write_files_data(raw_files_data)
@@ -23,6 +26,7 @@
 @simple_io.task
 def load_files_data():
+    """Load data from workspace filesystem."""
     current_run.log_info("Loading files data...")
 
     return pd.read_csv(f"{workspace.files_path}/raw.csv")
@@ -30,6 +34,7 @@
 @simple_io.task
 def transform_and_write_files_data(raw_data: pd.DataFrame):
+    """Simulate a transformation on the provided dataframe and write data to workspace filesystem."""
     current_run.log_info("Transforming files data...")
 
     transformed_data = raw_data.copy()
@@ -41,6 +46,7 @@
 @simple_io.task
 def load_data_from_postgresql() -> pd.DataFrame:
+    """Perform a simple SELECT query in the workspace database."""
     current_run.log_info("Loading Postgres data...")
 
     engine = create_engine(workspace.database_url)
@@ -50,6 +56,7 @@
 @simple_io.task
 def transform_and_write_sql_data(raw_data: pd.DataFrame):
+    """Simulate a transform operation on the provided data and load it in the workspace database."""
     current_run.log_info("Transforming postgres data...")
 
     engine = create_engine(workspace.database_url)
@@ -61,6 +68,7 @@
 @simple_io.task
 def load_dhis2_data():
+    """Load data from DHIS2."""
     current_run.log_info("Loading DHIS2 data...")
 
     connection = workspace.dhis2_connection("dhis2-play")
diff --git a/openhexa/cli/__init__.py b/openhexa/cli/__init__.py
index 7272735..eeb05c4 100644
--- a/openhexa/cli/__init__.py
+++ b/openhexa/cli/__init__.py
@@ -1,3 +1,5 @@
+"""CLI package."""
+
 from .cli import app
 
 __all__ = ["app"]
diff --git a/openhexa/cli/api.py b/openhexa/cli/api.py
index e9b0618..4c46298 100644
--- a/openhexa/cli/api.py
+++ b/openhexa/cli/api.py
@@ -1,8 +1,11 @@
+"""Collection of functions that interact with the OpenHEXA API."""
+
 import base64
-import configparser
 import enum
 import io
 import os
+import typing
+from configparser import ConfigParser
 from importlib.metadata import version
 from pathlib import Path
 from zipfile import ZipFile
@@ -16,20 +19,29 @@
 class InvalidDefinitionError(Exception):
+    """Raised whenever pipeline parameters and/or pipeline options are
incompatible.""" + pass -class PipelineErrorEnum(enum.Enum): +class PipelineDefinitionErrorCode(enum.Enum): + """Enumeration of possible pipeline definition error codes.""" + PIPELINE_DOES_NOT_SUPPORT_PARAMETERS = "PIPELINE_DOES_NOT_SUPPORT_PARAMETERS" INVALID_TIMEOUT_VALUE = "INVALID_TIMEOUT_VALUE" -def is_debug(config: configparser.ConfigParser): +def is_debug(config: ConfigParser) -> bool: + """Determine whether the provided configuration has the debug flag.""" return config.getboolean("openhexa", "debug", fallback=False) def open_config(): - config = configparser.ConfigParser() + """Open the local configuration file using configparser. + + A default configuration file will be generated if the file does not exist. + """ + config = ConfigParser() if os.path.exists(CONFIGFILE_PATH): config.read(CONFIGFILE_PATH) else: @@ -44,12 +56,14 @@ def open_config(): return config -def save_config(config): +def save_config(config: ConfigParser): + """Save the provided configparser local configuration to disk.""" with open(CONFIGFILE_PATH, "w") as configfile: config.write(configfile) def graphql(config, query: str, variables=None, token=None): + """Perform a GraphQL request.""" url = config["openhexa"]["url"] + "/graphql/" if token is None: current_workspace = config["openhexa"]["current_workspace"] @@ -89,6 +103,7 @@ def graphql(config, query: str, variables=None, token=None): def get_workspace(config, slug: str, token: str): + """Get a single workspace.""" return graphql( config, """ @@ -104,7 +119,8 @@ def get_workspace(config, slug: str, token: str): )["workspace"] -def get_pipelines(config): +def list_pipelines(config): + """List all pipelines in the workspace.""" data = graphql( config, """ @@ -126,7 +142,8 @@ def get_pipelines(config): return data["pipelines"]["items"] -def get_pipeline(config, pipeline_code: str): +def get_pipeline(config, pipeline_code: str) -> dict[str, typing.Any]: + """Get a single pipeline.""" data = graphql( config, """ @@ -149,6 +166,7 @@ def get_pipeline(config, pipeline_code: str): def create_pipeline(config, pipeline_code: str, pipeline_name: str): + """Create a pipeline using the API.""" data = graphql( config, """ @@ -179,7 +197,8 @@ def create_pipeline(config, pipeline_code: str, pipeline_name: str): return data["createPipeline"]["pipeline"] -def delete_pipeline(config, id: str): +def delete_pipeline(config, pipeline_id: str): + """Delete a single pipeline.""" data = graphql( config, """ @@ -190,7 +209,7 @@ def delete_pipeline(config, id: str): } } """, - {"input": {"id": id}}, + {"input": {"id": pipeline_id}}, ) if not data["deletePipeline"]["success"]: @@ -200,7 +219,7 @@ def delete_pipeline(config, id: str): def ensure_is_pipeline_dir(pipeline_path: str): - # Ensure that there is a pipeline.py file in the directory + """Ensure that there is a pipeline.py file in the directory.""" if not os.path.isdir(pipeline_path): raise Exception(f"Path {pipeline_path} is not a directory") if not os.path.exists(pipeline_path): @@ -211,11 +230,15 @@ def ensure_is_pipeline_dir(pipeline_path: str): return True -def upload_pipeline(config, pipeline_directory_path: str): +def upload_pipeline(config, pipeline_directory_path: typing.Union[str, Path]): + """Upload the pipeline contained in the provided directory using the GraphQL API. + + The pipeline code will be zipped and base64-encoded before being sent to the backend. 
+    """
     pipeline = import_pipeline(pipeline_directory_path)
     directory = Path(os.path.abspath(pipeline_directory_path))
 
-    zipFile = io.BytesIO(b"")
+    zip_file = io.BytesIO(b"")
 
     if is_debug(config):
         click.echo("Generating ZIP file:")
@@ -228,7 +251,7 @@
     if env_vars.get("WORKSPACE_FILES_PATH") and Path(env_vars["WORKSPACE_FILES_PATH"]) not in excluded_paths:
         excluded_paths.append(Path(env_vars["WORKSPACE_FILES_PATH"]))
 
-    with ZipFile(zipFile, "w") as zipObj:
+    with ZipFile(zip_file, "w") as zipObj:
         for path in directory.glob("**/*"):
             if path.name == "python":
                 # We are in a virtual environment
@@ -252,15 +275,15 @@
                 click.echo(f"\t{file_path.name}")
             zipObj.write(file_path, file_path.relative_to(directory))
 
-    zipFile.seek(0)
+    zip_file.seek(0)
 
     if is_debug(config):
-        # Write zipFile to disk for debugging
+        # Write zip_file to disk for debugging
         with open("pipeline.zip", "wb") as debug_file:
-            debug_file.write(zipFile.read())
-            zipFile.seek(0)
+            debug_file.write(zip_file.read())
+            zip_file.seek(0)
 
-    base64_content = base64.b64encode(zipFile.read()).decode("ascii")
+    base64_content = base64.b64encode(zip_file.read()).decode("ascii")
 
     data = graphql(
         config,
@@ -285,11 +308,12 @@
     )
 
     if not data["uploadPipeline"]["success"]:
-        if PipelineErrorEnum.PIPELINE_DOES_NOT_SUPPORT_PARAMETERS.value in data["uploadPipeline"]["errors"]:
+        if PipelineDefinitionErrorCode.PIPELINE_DOES_NOT_SUPPORT_PARAMETERS.value in data["uploadPipeline"]["errors"]:
             raise InvalidDefinitionError(
-                "Cannot push a new version : this pipeline has a schedule and the new version is not schedulable (all parameters must be optional or have default values)."
+                "Cannot push a new version: this pipeline has a schedule and the new version cannot be scheduled "
+                "(all parameters must be optional or have default values)."
             )
-        elif PipelineErrorEnum.INVALID_TIMEOUT_VALUE.value in data["uploadPipeline"]["errors"]:
+        elif PipelineDefinitionErrorCode.INVALID_TIMEOUT_VALUE.value in data["uploadPipeline"]["errors"]:
             raise InvalidDefinitionError(
                 "Timeout value is invalid: ensure that it is not negative and under 12 hours."
) diff --git a/openhexa/cli/cli.py b/openhexa/cli/cli.py index 67b5131..9464eb3 100644 --- a/openhexa/cli/cli.py +++ b/openhexa/cli/cli.py @@ -1,3 +1,5 @@ +"""CLI module, with click.""" + import base64 import json import sys @@ -13,14 +15,13 @@ delete_pipeline, ensure_is_pipeline_dir, get_pipeline, - get_pipelines, get_workspace, is_debug, + list_pipelines, open_config, save_config, upload_pipeline, ) -from openhexa.cli.utils import terminate from openhexa.sdk.pipelines import get_local_workspace_config, import_pipeline @@ -29,9 +30,7 @@ @click.version_option(version("openhexa.sdk")) @click.pass_context def app(ctx, debug): - """ - OpenHexa CLI - """ + """OpenHEXA CLI.""" # ensure that ctx.obj exists and is a dict (in case `cli()` is called # by means other than the `if` block below) ctx.ensure_object(dict) @@ -44,9 +43,7 @@ def app(ctx, debug): @app.group(invoke_without_command=True) @click.pass_context def workspaces(ctx): - """ - Manage workspaces (add workspace, remove workspace, list workspaces, activate a workspace) - """ + """Manage workspaces (add workspace, remove workspace, list workspaces, activate a workspace).""" if ctx.invoked_subcommand is None: ctx.forward(workspaces_list) @@ -55,9 +52,7 @@ def workspaces(ctx): @click.argument("slug") @click.option("--token", prompt=True, hide_input=True, confirmation_prompt=False) def workspaces_add(slug, token): - """ - Add a workspace to the configuration and activate it. The access token is required to access the workspace. - """ + """Add a workspace to the configuration and activate it. The access token is required to access the workspace.""" user_config = open_config() if slug in user_config["workspaces"]: click.echo(f"Workspace {slug} already exists. We will only update its token.") @@ -83,10 +78,7 @@ def workspaces_add(slug, token): @workspaces.command(name="activate") @click.argument("slug") def workspaces_activate(slug): - """ - Activate a workspace that is already in the configuration. The activated workspace will be used for the 'pipelines' commands. - """ - + """Activate a workspace that is already in the configuration. The activated workspace will be used for the 'pipelines' commands.""" user_config = open_config() if slug not in user_config["workspaces"]: click.echo(f"Workspace {slug} does not exist on {user_config['openhexa']['url']}. Available workspaces:") @@ -100,9 +92,7 @@ def workspaces_activate(slug): @workspaces.command(name="list") def workspaces_list(): - """ - List the workspaces in the configuration. - """ + """List the workspaces in the configuration.""" user_config = open_config() click.echo("Workspaces:") @@ -119,8 +109,7 @@ def workspaces_list(): @workspaces.command(name="rm") @click.argument("slug") def workspaces_rm(slug): - """ - Remove a workspace from the configuration. + """Remove a workspace from the configuration. SLUG is the slug of the workspace to remove from the configuration. """ @@ -143,10 +132,7 @@ def workspaces_rm(slug): @app.group(invoke_without_command=True) @click.pass_context def config(ctx): - """ - Manage configuration of the CLI. - """ - + """Manage configuration of the CLI.""" if ctx.invoked_subcommand is None: user_config = open_config() click.echo("Debug: " + ("True" if is_debug(user_config) else "False")) @@ -162,10 +148,7 @@ def config(ctx): @config.command(name="set_url") @click.argument("url") def config_set_url(url): - """ - Set the URL of the backend. 
-
-    """
+    """Set the URL of the backend."""
     user_config = open_config()
     user_config["openhexa"].update({"url": url})
     save_config(user_config)
@@ -175,9 +158,7 @@
 @app.group(invoke_without_command=True)
 @click.pass_context
 def pipelines(ctx):
-    """
-    Manage pipelines (list pipelines, push a pipeline to the backend)
-    """
+    """Manage pipelines (list pipelines, push a pipeline to the backend)."""
     if ctx.invoked_subcommand is None:
         ctx.forward(pipelines_list)
@@ -185,10 +166,7 @@
 @pipelines.command("init")
 @click.argument("name", type=str)
 def pipelines_init(name: str):
-    """
-    Initialize a new pipeline in a fresh directory.
-    """
-
+    """Initialize a new pipeline in a fresh directory."""
     new_pipeline_directory_name = stringcase.snakecase(name.lower())
     new_pipeline_path = Path.cwd() / Path(new_pipeline_directory_name)
     if new_pipeline_path.exists():
@@ -201,16 +179,16 @@
     # Load samples
     sample_directory_path = Path(__file__).parent / Path("skeleton")
-    with open(sample_directory_path / Path(".gitignore"), "r") as sample_ignore_file:
+    with open(sample_directory_path / Path(".gitignore")) as sample_ignore_file:
         sample_ignore_content = sample_ignore_file.read()
-    with open(sample_directory_path / Path("pipeline.py"), "r") as sample_pipeline_file:
+    with open(sample_directory_path / Path("pipeline.py")) as sample_pipeline_file:
         sample_pipeline_content = (
             sample_pipeline_file.read()
             .replace("skeleton-pipeline-code", stringcase.spinalcase(name.lower()))
             .replace("skeleton_pipeline_name", stringcase.snakecase(name.lower()))
             .replace("Skeleton pipeline name", name)
         )
-    with open(sample_directory_path / Path("workspace.yaml"), "r") as sample_workspace_file:
+    with open(sample_directory_path / Path("workspace.yaml")) as sample_workspace_file:
         sample_workspace_content = sample_workspace_file.read()
 
     # Create directory
@@ -237,12 +215,10 @@
 @pipelines.command("push")
 @click.argument("path", default=".", type=click.Path(exists=True, file_okay=False, dir_okay=True))
 def pipelines_push(path: str):
-    """
-    Push a pipeline to the backend. If the pipeline already exists, it will be updated otherwise it will be created.
+    """Push a pipeline to the backend. If the pipeline already exists, it will be updated; otherwise it will be created.
 
     PATH is the path to the pipeline file.
     """
-
     user_config = open_config()
     try:
         workspace = user_config["openhexa"]["current_workspace"]
@@ -264,7 +240,7 @@
             raise e
         sys.exit(1)
     else:
-        workspace_pipelines = get_pipelines(user_config)
+        workspace_pipelines = list_pipelines(user_config)
         if is_debug(user_config):
             click.echo(workspace_pipelines)
@@ -291,17 +267,17 @@
             "api", "app"
         )
         click.echo(
-            f"Done! You can view the pipeline in OpenHexa on {click.style(url, fg='bright_blue', underline=True)}"
+            f"Done! You can view the pipeline in OpenHEXA on {click.style(url, fg='bright_blue', underline=True)}"
         )
     except InvalidDefinitionError as e:
-        terminate(
+        _terminate(
             f'Pipeline definition is invalid: "{e}"',
             err=True,
             exception=e,
             debug=is_debug(user_config),
         )
     except Exception as e:
-        terminate(
+        _terminate(
             f'Error while importing pipeline: "{e}"',
             err=True,
             exception=e,
@@ -312,10 +288,7 @@
 @pipelines.command("delete")
 @click.argument("code", type=str)
 def pipelines_delete(code: str):
-    """
-    Delete a pipeline and all his versions.
-    """
-
+    """Delete a pipeline and all its versions."""
     user_config = open_config()
     try:
         workspace = user_config["openhexa"]["current_workspace"]
@@ -351,7 +324,7 @@
         click.echo(f"Pipeline {click.style(code, bold=True)} deleted.")
 
     except Exception as e:
-        terminate(
+        _terminate(
             f'Error while deleting pipeline: "{e}"',
             err=True,
             exception=e,
@@ -376,9 +349,7 @@
 def pipelines_run(
     config_str: str = "{}",
     config_file: click.File = None,
 ):
-    """
-    Run a pipeline locally.
-    """
+    """Run a pipeline locally."""
     from subprocess import Popen
 
     user_config = open_config()
@@ -440,9 +411,7 @@
 @pipelines.command("list")
 def pipelines_list():
-    """
-    List all the remote pipelines of the current workspace.
-    """
+    """List all the remote pipelines of the current workspace."""
     user_config = open_config()
 
     workspace = user_config["openhexa"]["current_workspace"]
@@ -450,15 +419,22 @@
         click.echo("No workspace activated", err=True)
         sys.exit(1)
 
-    workspace_pipelines = get_pipelines(user_config)
+    workspace_pipelines = list_pipelines(user_config)
     if len(workspace_pipelines) == 0:
         click.echo(f"No pipelines in workspace {workspace}")
         return
 
     click.echo("Pipelines:")
     for pipeline in workspace_pipelines:
-        version = pipeline["currentVersion"].get("number")
-        if version:
-            version = f"v{version}"
+        current_version = pipeline["currentVersion"].get("number")
+        if current_version is not None:
+            current_version = f"v{current_version}"
         else:
-            version = "N/A"
-        click.echo(f"* {pipeline['code']} - {pipeline['name']} ({version})")
+            current_version = "N/A"
+        click.echo(f"* {pipeline['code']} - {pipeline['name']} ({current_version})")
+
+
+def _terminate(message: str, exception: Exception = None, err: bool = False, debug: bool = False):
+    click.echo(message, err=err)
+    if debug and exception:
+        raise exception
+    sys.exit(1)
diff --git a/openhexa/cli/skeleton/pipeline.py b/openhexa/cli/skeleton/pipeline.py
index 299de50..7b7156a 100644
--- a/openhexa/cli/skeleton/pipeline.py
+++ b/openhexa/cli/skeleton/pipeline.py
@@ -1,14 +1,21 @@
+"""Template for newly generated pipelines."""
+
 from openhexa.sdk import current_run, pipeline
 
 
 @pipeline("skeleton-pipeline-code", name="Skeleton pipeline name")
 def skeleton_pipeline_name():
+    """Write your pipeline orchestration here.
+
+    Pipeline functions should only call tasks and should never perform IO operations or expensive computations.
+    """
     count = task_1()
     task_2(count)
 
 
 @skeleton_pipeline_name.task
 def task_1():
+    """Put some data processing code here."""
     current_run.log_info("In task 1...")
 
     return 42
@@ -16,6 +23,7 @@
 @skeleton_pipeline_name.task
 def task_2(count):
+    """Put some data processing code here."""
     current_run.log_info(f"In task 2... count is {count}")
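The skeleton's guidance ("pipeline functions should only call tasks") pairs with the run outputs API introduced in `run.py` further down in this diff. A hedged sketch of a task that writes a workspace file and surfaces it on the run page (pipeline code and file name are illustrative):

```python
"""Sketch: surfacing a task's file output on the pipeline run page (names illustrative)."""

import pandas as pd

from openhexa.sdk import current_run, pipeline, workspace


@pipeline("outputs-demo", name="Outputs demo")
def outputs_demo():
    export_report()


@outputs_demo.task
def export_report():
    """Write a CSV into the workspace filesystem and record it as a run output."""
    df = pd.DataFrame({"value": [1, 2, 3]})
    path = f"{workspace.files_path}/report.csv"  # files_path as used in the simple_io example
    df.to_csv(path, index=False)
    current_run.log_info("Wrote report.csv")
    current_run.add_file_output(path)  # shows up on the run page, per run.py below
```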
diff --git a/openhexa/cli/utils.py b/openhexa/cli/utils.py
deleted file mode 100644
index e29c890..0000000
--- a/openhexa/cli/utils.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import sys
-
-import click
-
-
-def terminate(message: str, exception: Exception = None, err: bool = False, debug: bool = False):
-    click.echo(message, err=err)
-    if debug and exception:
-        raise exception
-    sys.exit(1)
diff --git a/openhexa/sdk/__init__.py b/openhexa/sdk/__init__.py
index d7d1a71..d67e936 100644
--- a/openhexa/sdk/__init__.py
+++ b/openhexa/sdk/__init__.py
@@ -1,6 +1,8 @@
+"""SDK package."""
+
 from .pipelines import current_run, parameter, pipeline
 from .workspaces import workspace
-from .workspaces.connection import DHIS2Connection, IASOConnection, PostgreSQLConnection, GCSConnection, S3Connection
+from .workspaces.connection import DHIS2Connection, GCSConnection, IASOConnection, PostgreSQLConnection, S3Connection
 
 __all__ = [
     "workspace",
diff --git a/openhexa/sdk/datasets/__init__.py b/openhexa/sdk/datasets/__init__.py
index 7dab1bd..20efeca 100644
--- a/openhexa/sdk/datasets/__init__.py
+++ b/openhexa/sdk/datasets/__init__.py
@@ -1,3 +1,10 @@
+"""Datasets package.
+
+See https://github.com/BLSQ/openhexa/wiki/User-manual#datasets and
+https://github.com/BLSQ/openhexa/wiki/Using-the-OpenHEXA-SDK#working-with-datasets for more information about OpenHEXA
+datasets.
+"""
+
 from .dataset import Dataset, DatasetFile
 
 __all__ = ["Dataset", "DatasetFile"]
diff --git a/openhexa/sdk/datasets/dataset.py b/openhexa/sdk/datasets/dataset.py
index 2ce511b..f22b87c 100644
--- a/openhexa/sdk/datasets/dataset.py
+++ b/openhexa/sdk/datasets/dataset.py
@@ -1,3 +1,9 @@
+"""Dataset-related classes and functions.
+
+See https://github.com/BLSQ/openhexa/wiki/User-manual#datasets and
+https://github.com/BLSQ/openhexa/wiki/Using-the-OpenHEXA-SDK#working-with-datasets for more information about datasets.
+"""
+
 import mimetypes
 import typing
 from os import PathLike
@@ -9,6 +15,8 @@
 class DatasetFile:
+    """Represent a single file within a dataset. Files are attached to datasets through versions."""
+
     _download_url = None
     version = None
 
@@ -21,12 +29,14 @@ def __init__(self, version: any, id: str, uri: str, filename: str, content_type:
         self.created_at = created_at
 
     def read(self):
+        """Download the file content and return it."""
         response = requests.get(self.download_url, stream=True)
         response.raise_for_status()
         return response.content
 
     @property
     def download_url(self):
+        """Build and return a pre-signed URL for the file."""
         if self._download_url is None:
             response = graphql(
                 """
@@ -46,10 +56,13 @@
         return self._download_url
 
     def __repr__(self) -> str:
+        """Safe representation of the dataset file."""
         return f"<DatasetFile {self.filename}>"
 
 
 class VersionsIterator(Iterator):
+    """Custom iterator class to iterate versions using our GraphQL API."""
+
     def __init__(self, dataset: any, per_page: int = 10):
         super().__init__(per_page=per_page)
 
@@ -90,6 +103,8 @@ def _next_page(self):
 
 class VersionFilesIterator(Iterator):
+    """Custom iterator class to iterate version files using our GraphQL API."""
+
     def __init__(self, version: any, per_page: int = 20):
         super().__init__(per_page=per_page)
         self.item_to_value = lambda x: DatasetFile(
@@ -137,6 +152,8 @@ def _next_page(self):
 
 class DatasetVersion:
+    """Dataset files are not directly attached to a dataset, but rather to a version."""
+
     dataset = None
 
     def __init__(self, dataset: any, id: str, name: str, created_at: str):
@@ -147,11 +164,13 @@
     @property
     def files(self):
+        """Build and return an iterator of files for this version."""
         if self.id is None:
             raise ValueError("This dataset version does not have an id.")
         return VersionFilesIterator(version=self, per_page=50)
 
-    def get_file(self, filename: str):
+    def get_file(self, filename: str) -> DatasetFile:
+        """Get a file by name."""
         data = graphql(
             """
             query getDatasetFile($versionId: ID!, $filename: String!) {
@@ -171,7 +190,7 @@
         file = data["datasetVersion"]["fileByName"]
 
         if file is None:
-            return None
+            raise FileNotFound(f"The file {filename} does not exist for version {self}")
 
         return DatasetFile(
             version=self,
@@ -186,7 +205,8 @@
         self,
         source: typing.Union[str, PathLike[str], typing.IO],
         filename: typing.Optional[str] = None,
-    ):
+    ) -> DatasetFile:
+        """Create a new dataset file and add it to the dataset version."""
         mime_type = None
         if isinstance(source, (str, PathLike)):
             path = Path(source)
@@ -246,6 +266,11 @@
 class Dataset:
+    """Datasets are versioned, documented files.
+
+    See https://github.com/BLSQ/openhexa/wiki/Using-the-OpenHEXA-SDK#working-with-datasets for more information.
+    """
+
     _latest_version = None
 
     def __init__(
@@ -260,9 +285,8 @@
         self.name = name
         self.description = description
 
-    def create_version(self, name: typing.Any):
-        # Check that all files exist
-
+    def create_version(self, name: typing.Any) -> DatasetVersion:
+        """Build a dataset version, save it and return it."""
         response = graphql(
             """
             mutation createDatasetVersion($input: CreateDatasetVersionInput!) {
@@ -299,7 +323,11 @@
         return self.latest_version
 
     @property
-    def latest_version(self):
+    def latest_version(self) -> typing.Optional[DatasetVersion]:
+        """Return the latest version, if any.
+
+        This property method will query the backend to try to fetch the latest version.
+ """ if self._latest_version is None: data = graphql( """ @@ -329,8 +357,14 @@ def latest_version(self): return self._latest_version @property - def versions(self): + def versions(self) -> VersionsIterator: + """Build and return an iterator for versions.""" return VersionsIterator(dataset=self, per_page=10) def __repr__(self) -> str: + """Safe representation of the dataset.""" return f"" + + +class FileNotFound(Exception): + """Raised whenever an attempt is made to get a file that does not exist.""" diff --git a/openhexa/sdk/pipelines/__init__.py b/openhexa/sdk/pipelines/__init__.py index c2b3411..aa51bcd 100644 --- a/openhexa/sdk/pipelines/__init__.py +++ b/openhexa/sdk/pipelines/__init__.py @@ -1,11 +1,18 @@ +"""Pipelines package. + +See https://github.com/BLSQ/openhexa/wiki/User-manual#using-pipelines and +https://github.com/BLSQ/openhexa/wiki/Writing-OpenHEXA-pipelines for more information about OpenHEXA pipelines. +""" + from .parameter import parameter -from .pipeline import pipeline +from .pipeline import Pipeline, pipeline from .run import current_run from .runtime import download_pipeline, import_pipeline from .utils import get_local_workspace_config __all__ = [ "pipeline", + "Pipeline", "parameter", "current_run", "import_pipeline", diff --git a/openhexa/sdk/pipelines/parameter.py b/openhexa/sdk/pipelines/parameter.py index 53bddab..04eafc6 100644 --- a/openhexa/sdk/pipelines/parameter.py +++ b/openhexa/sdk/pipelines/parameter.py @@ -1,40 +1,43 @@ +"""Pipeline parameters classes and functions. + +See https://github.com/BLSQ/openhexa/wiki/Writing-OpenHEXA-pipelines#pipeline-parameters for more information. +""" import re import typing + +from openhexa.sdk.workspaces import workspace from openhexa.sdk.workspaces.connection import ( + Connection, + CustomConnection, DHIS2Connection, + GCSConnection, IASOConnection, PostgreSQLConnection, S3Connection, - GCSConnection, ) -from openhexa.sdk.workspaces import workspace from openhexa.sdk.workspaces.workspace import ConnectionDoesNotExist -class ParameterValueError(Exception): - pass - - class ParameterType: - """Base class for parameter types. Those parameter types are used when using the @parameter decorator""" + """Base class for parameter types. Those parameter types are used when using the @parameter decorator.""" def spec_type(self) -> str: - """Returns a type string for the specs that are sent to the backend.""" - + """Return a type string for the specs that are sent to the backend.""" raise NotImplementedError @property - def expected_type(self) -> typing.Type: + def expected_type(self) -> type: """Returns the python type expected for values.""" - raise NotImplementedError @property - def accepts_choice(self) -> bool: + def accepts_choices(self) -> bool: + """Return True only if the parameter type supports the "choices" optional argument.""" return True @property def accepts_multiple(self) -> bool: + """Return True only if the parameter type supports multiple values.""" return True @staticmethod @@ -44,12 +47,10 @@ def normalize(value: typing.Any) -> typing.Any: This can be used to handle empty values and normalize them to None, or to perform type conversions, allowing us to allow multiple input types but still normalize everything to a single type. 
""" - return value def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[typing.Any]: """Validate the provided value for this type.""" - if not isinstance(value, self.expected_type): raise ParameterValueError( f"Invalid type for value {value} (expected {self.expected_type}, got {type(value)})" @@ -58,23 +59,30 @@ def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[typing return value def validate_default(self, value: typing.Optional[typing.Any]): + """Validate the default value configured for this type.""" self.validate(value) def __str__(self) -> str: + """Cast parameter as string.""" return str(self.expected_type) class StringType(ParameterType): + """Type class for string parameters.""" + @property def spec_type(self) -> str: + """Return a type string for the specs that are sent to the backend.""" return "str" @property - def expected_type(self) -> typing.Type: + def expected_type(self) -> type: + """Returns the python type expected for values.""" return str @staticmethod def normalize(value: typing.Any) -> typing.Optional[str]: + """Strip leading and trailing whitespaces and convert empty strings to None.""" if isinstance(value, str): normalized_value = value.strip() else: @@ -86,6 +94,7 @@ def normalize(value: typing.Any) -> typing.Optional[str]: return normalized_value def validate_default(self, value: typing.Optional[typing.Any]): + """Validate the default value configured for this type.""" if value == "": raise ParameterValueError("Empty values are not accepted.") @@ -93,44 +102,59 @@ def validate_default(self, value: typing.Optional[typing.Any]): class Boolean(ParameterType): + """Type class for boolean parameters.""" + @property def spec_type(self) -> str: + """Return a type string for the specs that are sent to the backend.""" return "bool" @property - def expected_type(self) -> typing.Type: + def expected_type(self) -> type: + """Returns the python type expected for values.""" return bool @property - def accepts_choice(self) -> bool: + def accepts_choices(self) -> bool: + """Return a type string for the specs that are sent to the backend.""" return False @property def accepts_multiple(self) -> bool: + """Return a type string for the specs that are sent to the backend.""" return False class Integer(ParameterType): + """Type class for integer parameters.""" + @property def spec_type(self) -> str: + """Return a type string for the specs that are sent to the backend.""" return "int" @property - def expected_type(self) -> typing.Type: + def expected_type(self) -> type: + """Returns the python type expected for values.""" return int class Float(ParameterType): + """Type class for float parameters.""" + @property def spec_type(self) -> str: + """Return a type string for the specs that are sent to the backend.""" return "float" @property - def expected_type(self) -> typing.Type: + def expected_type(self) -> type: + """Returns the python type expected for values.""" return float @staticmethod def normalize(value: typing.Any) -> typing.Any: + """Normalize int values to float values if appropriate.""" if isinstance(value, int): return float(value) @@ -138,15 +162,20 @@ def normalize(value: typing.Any) -> typing.Any: class ConnectionParameterType(ParameterType): + """Abstract base class for connection parameter type classes.""" + @property - def accepts_choice(self) -> bool: + def accepts_choices(self) -> bool: + """Return True only if the parameter type supports the "choice values.""" return False @property def accepts_multiple(self) -> bool: + 
"""Return True only if the parameter type supports multiple values.""" return False def validate_default(self, value: typing.Optional[typing.Any]): + """Validate the default value configured for this type.""" if value is None: return @@ -155,7 +184,8 @@ def validate_default(self, value: typing.Optional[typing.Any]): elif value == "": raise ParameterValueError("Empty values are not accepted.") - def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[str]: + def validate(self, value: typing.Optional[typing.Any]) -> Connection: + """Validate the provided value for this type.""" if not isinstance(value, str): raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})") @@ -164,86 +194,117 @@ def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[str]: except ConnectionDoesNotExist as e: raise ParameterValueError(str(e)) - def to_connection(self, value: str) -> typing.Any: + def to_connection(self, value: str) -> Connection: + """Build a connection instance from the provided value (which should be a connection identifier).""" raise NotImplementedError class PostgreSQLConnectionType(ConnectionParameterType): + """Type class for PostgreSQL connections.""" + @property def spec_type(self) -> str: + """Return a type string for the specs that are sent to the backend.""" return "postgresql" @property - def expected_type(self) -> typing.Type: + def expected_type(self) -> type: + """Returns the python type expected for values.""" return PostgreSQLConnectionType - def to_connection(self, value: str) -> typing.Any: + def to_connection(self, value: str) -> PostgreSQLConnection: + """Build a PostgreSQL connection instance from the provided value (which should be a connection identifier).""" return workspace.postgresql_connection(value) class S3ConnectionType(ConnectionParameterType): + """Type class for S3 connections.""" + @property def spec_type(self) -> str: + """Return a type string for the specs that are sent to the backend.""" return "s3" @property - def expected_type(self) -> typing.Type: + def expected_type(self) -> type: + """Returns the python type expected for values.""" return S3ConnectionType - def to_connection(self, value: str) -> typing.Any: + def to_connection(self, value: str) -> S3Connection: + """Build a S3 connection instance from the provided value (which should be a connection identifier).""" return workspace.s3_connection(value) class GCSConnectionType(ConnectionParameterType): + """Type class for GCS connections.""" + @property def spec_type(self) -> str: + """Return a type string for the specs that are sent to the backend.""" return "gcs" @property - def expected_type(self) -> typing.Type: + def expected_type(self) -> type: + """Returns the python type expected for values.""" return GCSConnectionType - def to_connection(self, value: str) -> typing.Any: + def to_connection(self, value: str) -> GCSConnection: + """Build a GCS connection instance from the provided value (which should be a connection identifier).""" return workspace.gcs_connection(value) class DHIS2ConnectionType(ConnectionParameterType): + """Type class for DHIS2 connections.""" + @property def spec_type(self) -> str: + """Return a type string for the specs that are sent to the backend.""" return "dhis2" @property - def expected_type(self) -> typing.Type: + def expected_type(self) -> type: + """Returns the python type expected for values.""" return DHIS2ConnectionType - def to_connection(self, value: str) -> typing.Any: + def to_connection(self, 
+        """Build a DHIS2 connection instance from the provided value (which should be a connection identifier)."""
         return workspace.dhis2_connection(value)
 
 
 class IASOConnectionType(ConnectionParameterType):
+    """Type class for IASO connections."""
+
     @property
     def spec_type(self) -> str:
+        """Return a type string for the specs that are sent to the backend."""
         return "iaso"
 
     @property
-    def expected_type(self) -> typing.Type:
+    def expected_type(self) -> type:
+        """Returns the python type expected for values."""
         return IASOConnectionType
 
-    def to_connection(self, value: str) -> typing.Any:
+    def to_connection(self, value: str) -> IASOConnection:
+        """Build an IASO connection instance from the provided value (which should be a connection identifier)."""
         return workspace.iaso_connection(value)
 
 
 class CustomConnectionType(ConnectionParameterType):
+    """Type class for custom connections."""
+
     @property
     def spec_type(self) -> str:
+        """Return a type string for the specs that are sent to the backend."""
         return "custom"
 
     @property
-    def expected_type(self) -> typing.Type:
+    def expected_type(self) -> type:
+        """Returns the python type expected for values."""
         return str
 
-    def to_connection(self, value: str) -> typing.Any:
-        return workspace.postgresql_connection(value)
+    def to_connection(self, value: str) -> CustomConnection:
+        """Build a custom connection instance from the provided value (which should be a connection identifier)."""
+        return workspace.custom_connection(value)
 
 
 TYPES_BY_PYTHON_TYPE = {
@@ -259,10 +320,6 @@ def to_connection(self, value: str) -> typing.Any:
 }
 
 
-class InvalidParameterError(Exception):
-    pass
-
-
 class Parameter:
     """Pipeline parameter class. Contains validation and specs generation logic."""
 
@@ -270,7 +327,7 @@ def __init__(
         self,
         code: str,
         *,
-        type: typing.Union[typing.Type[str], typing.Type[int], typing.Type[bool]],
+        type: typing.Union[type[str], type[int], type[bool]],
         name: typing.Optional[str] = None,
         choices: typing.Optional[typing.Sequence] = None,
         help: typing.Optional[str] = None,
@@ -280,7 +337,8 @@
     ):
         if re.match("^[a-z_][a-z_0-9]+$", code) is None:
             raise InvalidParameterError(
-                f"Invalid parameter code provided ({code}). Parameter must start with a letter or an underscore, and can only contain lower case letters, numbers and underscores."
+                f"Invalid parameter code provided ({code}). Parameter must start with a letter or an underscore, "
+                f"and can only contain lower case letters, numbers and underscores."
             )
 
         self.code = code
@@ -290,11 +348,12 @@
         except KeyError:
             valid_parameter_types = [str(k) for k in TYPES_BY_PYTHON_TYPE.keys()]
             raise InvalidParameterError(
-                f"Invalid parameter type provided ({type}). Valid parameter types are {', '.join(valid_parameter_types)}"
+                f"Invalid parameter type provided ({type}). "
" + f"Valid parameter types are {', '.join(valid_parameter_types)}" ) if choices is not None: - if not self.type.accepts_choice: + if not self.type.accepts_choices: raise InvalidParameterError(f"Parameters of type {self.type} don't accept choices.") elif len(choices) == 0: raise InvalidParameterError("Choices, if provided, cannot be empty.") @@ -318,8 +377,7 @@ def __init__( self.default = default def validate(self, value: typing.Any) -> typing.Any: - """Validates the provided value against the parameter, taking required / default options into account.""" - + """Validate the provided value against the parameter, taking required / default options into account.""" if self.multiple: return self._validate_multiple(value) else: @@ -383,9 +441,8 @@ def _validate_default(self, default: typing.Any, multiple: bool): except ParameterValueError: raise InvalidParameterError(f"The default value for {self.code} is not valid.") - def parameter_spec(self): - """Generates specification for this parameter, to be provided to the OpenHexa backend.""" - + def parameter_spec(self) -> dict[str, typing.Any]: + """Build specification for this parameter, to be provided to the OpenHEXA backend.""" return { "type": self.type.spec_type, "required": self.required, @@ -402,15 +459,15 @@ def parameter( code: str, *, type: typing.Union[ - typing.Type[str], - typing.Type[int], - typing.Type[bool], - typing.Type[float], - typing.Type[DHIS2Connection], - typing.Type[IASOConnection], - typing.Type[PostgreSQLConnection], - typing.Type[GCSConnection], - typing.Type[S3Connection], + type[str], + type[int], + type[bool], + type[float], + type[DHIS2Connection], + type[IASOConnection], + type[PostgreSQLConnection], + type[GCSConnection], + type[S3Connection], ], name: typing.Optional[str] = None, choices: typing.Optional[typing.Sequence] = None, @@ -419,7 +476,7 @@ def parameter( required: bool = True, multiple: bool = False, ): - """Decorator that attaches a parameter to an OpenHexa pipeline. + """Decorate a pipeline function by attaching a parameter to it.. This decorator must be used on a function decorated by the @pipeline decorator. 
@@ -470,17 +527,31 @@ def decorator(fun):
 
 
 class FunctionWithParameter:
-    """This class serves as a wrapper for functions decorated with the @parameter decorator."""
+    """Wrapper class for pipeline functions decorated with the @parameter decorator."""
 
     def __init__(self, function, added_parameter: Parameter):
         self.function = function
         self.parameter = added_parameter
 
-    def __call__(self, *args, **kwargs):
-        return self.function(*args, **kwargs)
-
-    def get_all_parameters(self):
+    def get_all_parameters(self) -> list[Parameter]:
+        """Go through the decorator chain to find all pipeline parameters."""
         if isinstance(self.function, FunctionWithParameter):
             return [self.parameter, *self.function.get_all_parameters()]
         return [self.parameter]
+
+    def __call__(self, *args, **kwargs):
+        """Call the decorated pipeline function."""
+        return self.function(*args, **kwargs)
+
+
+class InvalidParameterError(Exception):
+    """Raised whenever parameter options (usually passed to the @parameter decorator) are invalid."""
+
+    pass
+
+
+class ParameterValueError(Exception):
+    """Raised whenever values for a parameter provided for a pipeline run are invalid."""
+
+    pass
diff --git a/openhexa/sdk/pipelines/pipeline.py b/openhexa/sdk/pipelines/pipeline.py
index 0d2b32a..260aadb 100644
--- a/openhexa/sdk/pipelines/pipeline.py
+++ b/openhexa/sdk/pipelines/pipeline.py
@@ -1,4 +1,9 @@
-from __future__ import annotations
+"""Main pipeline module containing the building blocks for OpenHEXA pipelines.
+
+See https://github.com/BLSQ/openhexa/wiki/User-manual#using-pipelines and
+https://github.com/BLSQ/openhexa/wiki/Writing-OpenHEXA-pipelines for more information about OpenHEXA pipelines.
+"""
+
 
 import argparse
 import datetime
@@ -14,64 +19,38 @@
 import requests
 from multiprocess import get_context  # NOQA
 
-from openhexa.sdk.utils import Environments, get_environment
+from openhexa.sdk.utils import Environment, get_environment
 
 from .parameter import (
     FunctionWithParameter,
     Parameter,
     ParameterValueError,
 )
-from .task import PipelineWithTask
+from .task import PipelineWithTask, Task
 from .utils import get_local_workspace_config
 
 logger = getLogger(__name__)
 
 
-class PipelineConfigError(Exception):
-    pass
-
+class Pipeline:
+    """OpenHEXA pipeline class.
 
-def pipeline(
-    code: str, *, name: str = None, timeout: int = None
-) -> typing.Callable[[typing.Callable[..., typing.Any]], "Pipeline"]:
-    """Decorator that turns a Python function into an OpenHexa pipeline.
+    Pipelines are usually instantiated through the @pipeline decorator.
 
-    Parameters
+    Attributes
     ----------
     code : str
-        An identifier for the pipeline (should be unique within the workspace where the pipeline is deployed)
-    name : str, optional
-        An optional name for the pipeline (will be used instead of the code in the web interface)
-    timeout : int, optional
-        An optional timeout, in seconds, after which the pipeline run will be terminated (if not provided, a default
-        timeout will be applied by the OpenHexa backend)
-
-    Returns
-    -------
-    typing.Callable
-        A decorator that returns a Pipeline
-
+        A unique code to identify the pipeline within a workspace.
+    name : str
+        A user-friendly name for the pipeline (will be displayed in the web interface).
+    function: typing.Callable
+        The actual pipeline function.
+    parameters : typing.Sequence[Parameter]
+        A list of Parameter instances corresponding to the pipeline parameters.
+    timeout : int
+        The timeout in seconds after which the pipeline will be killed.
""" - if any(c not in string.ascii_lowercase + string.digits + "_-" for c in code): - raise Exception("Pipeline code should contains only lower case letters, digits, '_' and '-'") - - def decorator(fun): - if isinstance(fun, FunctionWithParameter): - parameters = fun.get_all_parameters() - else: - parameters = [] - - return Pipeline(code, name, fun, parameters, timeout) - - return decorator - - -class PipelineRunError(Exception): - pass - - -class Pipeline: def __init__( self, code: str, @@ -87,12 +66,34 @@ def __init__( self.timeout = timeout self.tasks = [] - def task(self, function): - """task decorator""" - + def task(self, function) -> PipelineWithTask: + """Task decorator. + + Examples + -------- + >>> @pipeline("my-pipeline") + ... def my_pipeline(): + ... result_1 = task1() + ... task2(result_1) + ... + ... @my_pipeline.task + ... def task_1() -> int: + ... return 42 + ... + ... @my_pipeline.task + ... def task_2(foo: int): + ... pass + """ return PipelineWithTask(function, self) - def run(self, config: typing.Dict[str, typing.Any]): + def run(self, config: dict[str, typing.Any]): + """Run the pipeline using the provided config. + + Parameters + ---------- + config : typing.Dict[str, typing.Any] + The parameter values to use for this pipeline run. + """ now = datetime.datetime.utcnow().replace(microsecond=0).isoformat() print(f'{now} Starting pipeline "{self.code}"') @@ -116,7 +117,7 @@ def run(self, config: typing.Dict[str, typing.Any]): completed = 0 while True: - tasks = self.get_available_tasks() + tasks = self._get_available_tasks() # filter already pooled task, even if they are not finished tasks = [t for t in tasks if not t.pooled] @@ -166,12 +167,13 @@ def run(self, config: typing.Dict[str, typing.Any]): print(f'{now} Successfully completed pipeline "{self.code}"') - def get_available_tasks(self): - return [task for task in self.tasks if task.is_ready()] - - def parameters_spec(self): + def parameters_spec(self) -> list[dict[str, typing.Any]]: + """Return the individual specifications of all the parameters of this pipeline.""" return [arg.parameter_spec() for arg in self.parameters] + def _get_available_tasks(self) -> list[Task]: + return [task for task in self.tasks if task.is_ready()] + def _update_progress(self, progress: int): if self._connected: token = os.environ["HEXA_TOKEN"] @@ -193,19 +195,25 @@ def _update_progress(self, progress: int): print(f"Progress update: {progress}%") @property - def _connected(self): + def _connected(self) -> bool: env = get_environment() - return env == Environments.CLOUD_PIPELINE and "HEXA_SERVER_URL" in os.environ - def __call__(self, config: typing.Optional[typing.Dict[str, typing.Any]] = None): + return env == Environment.CLOUD_PIPELINE and "HEXA_SERVER_URL" in os.environ + + def __call__(self, config: typing.Optional[dict[str, typing.Any]] = None): + """Call the pipeline by running it, after having configured the environment. + + This method can be called with an explicit configuration. If no configuration is provided, it will parse the + command-line arguments to build it. + """ # Handle local workspace config for dev / testing, if appropriate - if get_environment() == Environments.LOCAL_PIPELINE: + if get_environment() == Environment.LOCAL_PIPELINE: os.environ.update(get_local_workspace_config(Path("/home/hexa/pipeline"))) # User can run their pipeline using `python pipeline.py`. It's considered as a standalone usage of the library. 
         # Since we still support this use case for the moment, we'll try to load the workspace.yaml
         # at the path of the file
-        elif get_environment() == Environments.STANDALONE:
+        elif get_environment() == Environment.STANDALONE:
             os.environ.update(get_local_workspace_config(Path(sys.argv[0]).parent))
 
         if config is None:
             # Called without arguments, in the pipeline file itself
@@ -221,7 +229,7 @@
                     "config file with the --config-file/-f argument."
                 )
             if args.config_file is not None:
-                with open(args.config_file, "r") as cf:
+                with open(args.config_file) as cf:
                     try:
                         config = json.load(cf)
                     except json.JSONDecodeError:
@@ -236,3 +244,59 @@
                 config = {}
 
         self.run(config)
+
+
+def pipeline(
+    code: str, *, name: str = None, timeout: int = None
+) -> typing.Callable[[typing.Callable[..., typing.Any]], Pipeline]:
+    """Decorate a Python function as an OpenHEXA pipeline.
+
+    Parameters
+    ----------
+    code : str
+        An identifier for the pipeline (should be unique within the workspace where the pipeline is deployed)
+    name : str, optional
+        An optional name for the pipeline (will be used instead of the code in the web interface)
+    timeout : int, optional
+        An optional timeout, in seconds, after which the pipeline run will be terminated (if not provided, a default
+        timeout will be applied by the OpenHEXA backend)
+
+    Returns
+    -------
+    typing.Callable
+        A decorator that returns a Pipeline
+
+    Examples
+    --------
+    >>> @pipeline("my-pipeline")
+    ... def my_pipeline():
+    ...     a_task()
+    ...
+    ... @my_pipeline.task
+    ... def a_task() -> int:
+    ...     return 42
+    """
+    if any(c not in string.ascii_lowercase + string.digits + "_-" for c in code):
+        raise Exception("Pipeline code should contain only lower case letters, digits, '_' and '-'")
+
+    def decorator(fun):
+        if isinstance(fun, FunctionWithParameter):
+            parameters = fun.get_all_parameters()
+        else:
+            parameters = []
+
+        return Pipeline(code, name, fun, parameters, timeout)
+
+    return decorator
+
+
+class PipelineConfigError(Exception):
+    """Error raised whenever the config passed to the pipeline run method is invalid."""
+
+    pass
+
+
+class PipelineRunError(Exception):
+    """Generic pipeline runtime error, raised whenever user code raises an exception."""
+
+    pass
diff --git a/openhexa/sdk/pipelines/run.py b/openhexa/sdk/pipelines/run.py
index c904e14..c79bdaa 100644
--- a/openhexa/sdk/pipelines/run.py
+++ b/openhexa/sdk/pipelines/run.py
@@ -1,22 +1,29 @@
+"""Pipeline run module."""
+
 import datetime
 import os
 import typing
-from pathlib import Path
 
-from openhexa.sdk.utils import Environments, get_environment, graphql
+from openhexa.sdk.utils import Environment, get_environment, graphql
 from openhexa.sdk.workspaces import workspace
 
 
 class CurrentRun:
+    """Represents the current run of a pipeline.
+
+    CurrentRun instances allow pipeline developers to interact with the OpenHEXA backend, by sending messages and
+    adding outputs that will be available through the web interface.
+    """
+
     @property
     def _connected(self):
         return "HEXA_SERVER_URL" in os.environ
 
-    @property
-    def tmp_path(self):
-        return Path("~/tmp/")
-
     def add_file_output(self, path: str):
+        """Record a run output for a file creation operation.
+
+        This output will be visible in the web interface, on the pipeline run page.
+ """ stripped_path = path.replace(workspace.files_path, "") name = stripped_path.strip("/") if self._connected: @@ -37,6 +44,10 @@ def add_file_output(self, path: str): print(f"Sending output with path {stripped_path}") def add_database_output(self, table_name: str): + """Record a run output for a database operation. + + This output will be visible in the web interface, on the pipeline run page. + """ if self._connected: graphql( """ @@ -55,18 +66,23 @@ def add_database_output(self, table_name: str): print(f"Sending output with table_name {table_name}") def log_debug(self, message: str): + """Log a message with the DEBUG priority.""" self._log_message("DEBUG", message) def log_info(self, message: str): + """Log a message with the INFO priority.""" self._log_message("INFO", message) def log_warning(self, message: str): + """Log a message with the WARNING priority.""" self._log_message("WARNING", message) def log_error(self, message: str): + """Log a message with the ERROR priority.""" self._log_message("ERROR", message) def log_critical(self, message: str): + """Log a message with the CRITICAL priority.""" self._log_message("CRITICAL", message) def _log_message( @@ -91,7 +107,7 @@ def _log_message( print(now, priority, message) -if get_environment() == Environments.CLOUD_JUPYTER: +if get_environment() == Environment.CLOUD_JUPYTER: current_run = None else: current_run = CurrentRun() diff --git a/openhexa/sdk/pipelines/runtime.py b/openhexa/sdk/pipelines/runtime.py index bb4fdfe..d769d54 100644 --- a/openhexa/sdk/pipelines/runtime.py +++ b/openhexa/sdk/pipelines/runtime.py @@ -1,3 +1,5 @@ +"""Utilities used by containerized pipeline runners to import and download pipelines.""" + import base64 import importlib import io @@ -11,6 +13,7 @@ def import_pipeline(pipeline_dir_path: str): + """Import pipeline code within provided path using importlib.""" pipeline_dir = os.path.abspath(pipeline_dir_path) sys.path.append(pipeline_dir) pipeline_package = importlib.import_module("pipeline") @@ -19,7 +22,8 @@ def import_pipeline(pipeline_dir_path: str): return pipeline -def download_pipeline(url: str, token: str, run_id: str, target_dir): +def download_pipeline(url: str, token: str, run_id: str, target_dir: str): + """Download pipeline code and unzip it into the target directory.""" r = requests.post( url + "/graphql/", headers={"Authorization": f"Bearer {token}"}, diff --git a/openhexa/sdk/pipelines/task.py b/openhexa/sdk/pipelines/task.py index 31b9b1d..f545abf 100644 --- a/openhexa/sdk/pipelines/task.py +++ b/openhexa/sdk/pipelines/task.py @@ -1,3 +1,8 @@ +"""Classes and functions related to pipeline tasks. + +See https://github.com/BLSQ/openhexa/wiki/Writing-OpenHEXA-pipelines#pipelines-and-tasks for more information. +""" + from __future__ import annotations import datetime @@ -7,6 +12,11 @@ class TaskCom: + """Lightweight data transfer object allowing tasks to communicate. + + TaskCom instances also allow us to build the pipeline dependency graph. + """ + def __init__(self, task): self.result = task.result self.start_time = task.start_time @@ -14,6 +24,11 @@ def __init__(self, task): class Task: + """Tasks are pipeline data processing code units. + + See https://github.com/BLSQ/openhexa/wiki/Writing-OpenHEXA-pipelines#pipelines-and-tasks for more information. 
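+
+    Task instances are not meant to be created directly: decorating a function with the @<pipeline>.task decorator
+    wraps it in a Task and attaches it to the pipeline (see PipelineWithTask below).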
+ """ + def __init__(self, function: typing.Callable): self.name = function.__name__ self.compute = function @@ -26,27 +41,11 @@ def __init__(self, function: typing.Callable): self.active = False self.pooled = False - def __call__(self, *task_args, **task_kwargs): - self.active = True # uncalled tasks will be skipped - # check that all inputs are tasks - self.task_args = task_args - self.task_kwargs = task_kwargs - return self - - def __repr__(self): - return self.name - - def get_node_inputs(self): - inputs = [] - for a in self.task_args: - if issubclass(type(a), Task): - inputs.append(a) - for k, a in self.task_kwargs.items(): - if issubclass(type(a), Task): - inputs.append(a) - return inputs + def is_ready(self) -> bool: + """Determine whether the task is ready to be run. - def is_ready(self): + This involves checking whether tasks higher up in the dependency graph have been executed. + """ if not self.active: return False @@ -59,24 +58,32 @@ def is_ready(self): return True if self.end_time is None else False - def get_tasks_ready(self): + def get_ready_tasks(self) -> list[Task]: + """Find and return all tasks that can be launched at this point in time.""" tasks = [] for a in self.task_args: if issubclass(type(a), Task): if a.is_ready(): tasks.append(a) else: - tasks += a.get_tasks_ready() + tasks += a.get_ready_tasks() for k, a in self.task_kwargs.items(): if issubclass(type(a), Task): if a.is_ready(): tasks.append(a) else: - tasks += a.get_tasks_ready() + tasks += a.get_ready_tasks() return list(set(tasks)) - def run(self): + def run(self) -> TaskCom: + """Run the task. + + Returns + ------- + TaskCom + A TaskCom instance which can in turn be passed to other tasks. + """ if self.end_time: # already executed, return previous result return self.result @@ -106,22 +113,33 @@ def run(self): # done! return TaskCom(self) - def stateless_run(self): - self.result = None - self.start_time, self.end_time = None, None - return self.run() + def __call__(self, *task_args, **task_kwargs): + """Wrap the task with args and kwargs and return it.""" + self.active = True # uncalled tasks will be skipped + # check that all inputs are tasks + self.task_args = task_args + self.task_kwargs = task_kwargs + + return self + + def __repr__(self): + """Representation of the task using its name.""" + return self.name class PipelineWithTask: + """Pipeline with attached tasks, usually through the @task decorator.""" + def __init__( self, function: typing.Callable, - pipeline: openhexa.sdk.pipelines.pipeline.Pipeline, + pipeline: openhexa.sdk.pipelines.Pipeline, ): self.function = function self.pipeline = pipeline - def __call__(self, *task_args, **task_kwargs): + def __call__(self, *task_args, **task_kwargs) -> Task: + """Attach the new task to the decorated pipeline and return it.""" task = Task(self.function)(*task_args, **task_kwargs) self.pipeline.tasks.append(task) return task diff --git a/openhexa/sdk/pipelines/utils.py b/openhexa/sdk/pipelines/utils.py index 6e70ee0..7fae0fc 100644 --- a/openhexa/sdk/pipelines/utils.py +++ b/openhexa/sdk/pipelines/utils.py @@ -1,3 +1,5 @@ +"""Utilities for running local pipelines.""" + from pathlib import Path from tempfile import mkdtemp @@ -6,6 +8,8 @@ class LocalWorkspaceConfigError(Exception): + """Raised whenever the local workspace config file does not exist or is invalid.""" + pass @@ -18,7 +22,6 @@ def get_local_workspace_config(path: Path): This is obviously brittle as it relies on setting the correct env variables keys, any changes upstream must be reflected here. 
""" - env_vars = {} # This will only work when running the pipeline using "python pipeline.py" @@ -29,7 +32,7 @@ def get_local_workspace_config(path: Path): "To work with pipelines locally, you need a workspace.yaml file in the same directory as your pipeline file" ) - with open(local_workspace_config_path.resolve(), "r") as local_workspace_config_file: + with open(local_workspace_config_path.resolve()) as local_workspace_config_file: local_workspace_config = yaml.safe_load(local_workspace_config_file) # Database config if "database" in local_workspace_config: diff --git a/openhexa/sdk/utils.py b/openhexa/sdk/utils.py index fb983b2..4bd9cf1 100644 --- a/openhexa/sdk/utils.py +++ b/openhexa/sdk/utils.py @@ -1,3 +1,5 @@ +"""Miscellaneous utility functions.""" + import abc import enum import os @@ -6,7 +8,9 @@ import requests -class Environments(enum.Enum): +class Environment(enum.Enum): + """Enumeration of supported runtime environments.""" + LOCAL_PIPELINE = "LOCAL_PIPELINE" CLOUD_PIPELINE = "CLOUD_PIPELINE" CLOUD_JUPYTER = "CLOUD_JUPYTER" @@ -14,13 +18,17 @@ class Environments(enum.Enum): def get_environment(): + """Get the environment from the HEXA_ENVIRONMENT (see Environment enum).""" env = os.environ.get("HEXA_ENVIRONMENT", "STANDALONE").upper() - if env not in Environments.__members__: + + try: + return Environment[env] + except KeyError: raise ValueError(f"Invalid environment: {env}") - return Environments[env] -def graphql(operation: str, variables: typing.Optional[typing.Dict[str, typing.Any]] = None): +def graphql(operation: str, variables: typing.Optional[dict[str, typing.Any]] = None) -> dict[str, typing.Any]: + """Performa GraphQL query.""" auth_token = os.environ[ "HEXA_TOKEN" ] # Works for notebooks with the membership token and pipelines with the run token @@ -42,7 +50,7 @@ def graphql(operation: str, variables: typing.Optional[typing.Dict[str, typing.A return body["data"] -class Iterator(object, metaclass=abc.ABCMeta): +class Iterator(metaclass=abc.ABCMeta): """A generic class for iterating through API list responses.""" def __init__( @@ -68,42 +76,35 @@ def __init__( """int: The total number of results fetched so far.""" def _items_iter(self): - """Iterator for each item returned.""" for page in self._page_iter(increment=False): for item in page: self.num_results += 1 yield item - def __iter__(self): - """Iterator for each item returned. - - Returns: - types.GeneratorType[Any]: A generator of items from the API. - - Raises: - ValueError: If the iterator has already been started. - """ + def __iter__(self) -> typing.Generator[typing.Any, None, None]: + """Implement __iter().""" if self._started: raise ValueError("Iterator has already started", self) self._started = True + return self._items_iter() def __next__(self): + """Implement next().""" if self.__active_iterator is None: self.__active_iterator = iter(self) - return next(self.__active_iterator) - def _page_iter(self, increment): - """Generator of pages of API responses. + return next(self.__active_iterator) - Args: - increment (bool): Flag indicating if the total number of results - should be incremented on each page. This is useful since a page - iterator will want to increment by results per page while an - items iterator will want to increment per item. + def _page_iter(self, increment: bool): + """Generate pages of API responses. - Yields: - Page: each page of items from the API. + Parameters + ---------- + increment : bool + Flag indicating if the total number of results should be incremented on each page. 
+            This is useful since a page iterator will want to increment by results per page while an
+            items iterator will want to increment per item.
        """
        page = self._next_page()
        while page is not None:
@@ -120,24 +121,25 @@ def _next_page(self):
        This does nothing and is intended to be over-ridden by subclasses
        to return the next :class:`Page`.
 
-        Raises:
+        Raises
+        ------
            NotImplementedError: Always, this method is abstract.
        """
        raise NotImplementedError
 
 
-class Page(object):
+class Page:
    """Single page of results in an iterator.
 
-    Args:
-        parent (Iterator): The iterator that owns
-            the current page.
-        items (Sequence[Any]): An iterable (that also defines __len__) of items
-            from a raw API response.
-        item_to_value (Callable[google.api_core.page_iterator.Iterator, Any]):
-            Callable to convert an item from the type in the raw API response
-            into the native object. Will be called with the iterator and a
-            single item.
+    Parameters
+    ----------
+    parent : Iterator
+        The iterator that owns the current page.
+    items : Sequence[Any]
+        An iterable (that also defines __len__) of items from a raw API response.
+    item_to_value : Callable[dict[str, Any], Any]
+        Callable to convert an item from the type in the raw API response into the native object.
+        Will be called with the iterator and a single item.
    """
 
    def __init__(self, parent, items, item_to_value):
@@ -158,7 +160,7 @@ def remaining(self):
        return self._remaining
 
    def __iter__(self):
-        """The :class:`Page` is an iterator of items."""
+        """Implement __iter__()."""
        return self
 
    def __next__(self):
@@ -171,7 +173,8 @@ def __next__(self):
        return result
 
 
-def read_content(source: typing.Union[str, os.PathLike[str], typing.IO], encoding: str = "utf-8") -> str:
+def read_content(source: typing.Union[str, os.PathLike[str], typing.IO], encoding: str = "utf-8") -> bytes:
+    """Read file content and return it as bytes."""
    # If source is a string or PathLike object
    if isinstance(source, (str, os.PathLike)):
        with open(os.fspath(source), "rb") as f:
diff --git a/openhexa/sdk/workspaces/__init__.py b/openhexa/sdk/workspaces/__init__.py
index 8faffd8..94bf94c 100644
--- a/openhexa/sdk/workspaces/__init__.py
+++ b/openhexa/sdk/workspaces/__init__.py
@@ -1,3 +1,8 @@
+"""Workspaces package.
+
+See https://github.com/BLSQ/openhexa/wiki/User-manual#about-workspaces for more information about OpenHEXA workspaces.
+"""
+
 from .workspace import workspace
 
 __all__ = ["workspace"]
diff --git a/openhexa/sdk/workspaces/connection.py b/openhexa/sdk/workspaces/connection.py
index 51b79af..2f6c87a 100644
--- a/openhexa/sdk/workspaces/connection.py
+++ b/openhexa/sdk/workspaces/connection.py
@@ -1,18 +1,38 @@
+"""Connection module."""
+
 import dataclasses
 
 
 @dataclasses.dataclass
-class DHIS2Connection:
+class Connection:
+    """Abstract base class for connections."""
+
+    pass
+
+
+@dataclasses.dataclass
+class DHIS2Connection(Connection):
+    """DHIS2 connection.
+
+    See https://docs.dhis2.org/ for more information.
+    """
+
    url: str
    username: str
    password: str
 
    def __repr__(self):
+        """Safe representation of the DHIS2 connection (no credentials)."""
        return f"DHIS2Connection(url='{self.url}', username='{self.username}')"
 
 
 @dataclasses.dataclass
-class PostgreSQLConnection:
+class PostgreSQLConnection(Connection):
+    """PostgreSQL database connection.
+
+    See https://www.postgresql.org/docs/ for more information.
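+
+    Examples
+    --------
+    A minimal usage sketch ("my-database" is an illustrative connection identifier, SQLAlchemy an illustrative
+    client library):
+
+    >>> conn = workspace.postgresql_connection("my-database")
+    >>> engine = sqlalchemy.create_engine(conn.url)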
+ """ + host: str port: int username: str @@ -20,37 +40,76 @@ class PostgreSQLConnection: database_name: str def __repr__(self): - return f"PostgreSQLConnection(host='{self.host}', port='{self.port}', username='{self.username}', database_name='{self.database_name}')" + """Safe representation of the PostgreSQL connection (no credentials).""" + return ( + f"PostgreSQLConnection(host='{self.host}', port='{self.port}', username='{self.username}', " + f"database_name='{self.database_name}')" + ) @property def url(self): + """Provide a URL to the PostgreSQL database. + + The URL follows the official PostgreSQL specification. + (See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for more information) + """ return f"postgresql://{self.username}:{self.password}" f"@{self.host}:{self.port}/{self.database_name}" @dataclasses.dataclass -class S3Connection: +class S3Connection(Connection): + """AWS S3 connection. + + See https://docs.aws.amazon.com/s3/ for more information. + """ + access_key_id: str secret_access_key: str bucket_name: str def __repr__(self): + """Safe representation of the S3 connection (no credentials).""" return f"S3Connection(bucket_name='{self.bucket_name}')" @dataclasses.dataclass -class GCSConnection: +class GCSConnection(Connection): + """Google Cloud Storage connection. + + See https://cloud.google.com/storage/docs for more information. + """ + service_account_key: str bucket_name: str def __repr__(self): + """Safe representation of the GCS connection (no credentials).""" return f"GCSConnection(bucket_name='{self.bucket_name}')" +@dataclasses.dataclass +class CustomConnection(Connection): + """Marker class for custom connections. + + The actual class will be built dynamically through the Workspace.custom_connection() method. + """ + + def __repr__(self): + """Safe representation of the custom connection (no credentials).""" + return f"CustomConnection(name='{self.__class__.__name__.lower()}')" + + @dataclasses.dataclass class IASOConnection: + """IASO connection. + + See https://github.com/BLSQ/iaso for more information. + """ + url: str username: str password: str def __repr__(self): + """Safe representation of the IASO connection (no credentials).""" return f"IASOConnection(url='{self.url}', username='{self.username}')" diff --git a/openhexa/sdk/workspaces/workspace.py b/openhexa/sdk/workspaces/workspace.py index 743ec5c..c5ee69f 100644 --- a/openhexa/sdk/workspaces/workspace.py +++ b/openhexa/sdk/workspaces/workspace.py @@ -1,3 +1,8 @@ +"""Workspace-related classes and functions. + +See https://github.com/BLSQ/openhexa/wiki/User-manual#about-workspaces for more information. 
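+
+Most users do not need to instantiate CurrentWorkspace directly: the package exposes a ready-to-use `workspace`
+instance (imported in openhexa/sdk/workspaces/__init__.py).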
+""" + import os from dataclasses import make_dataclass from warnings import warn @@ -7,6 +12,7 @@ from ..datasets import Dataset from ..utils import graphql from .connection import ( + CustomConnection, DHIS2Connection, GCSConnection, IASOConnection, @@ -16,30 +22,41 @@ class WorkspaceConfigError(Exception): + """Raised whenever the system cannot find an environment variable required to configure the current workspace.""" + pass class ConnectionDoesNotExist(Exception): + """Raised whenever an attempt is made to get a connection through an invalid identifier.""" + pass class CurrentWorkspace: + """Represents the currently configured OpenHEXA workspace, with its filesystem, database and connections.""" + @property - def _token(self): + def _token(self) -> str: try: return os.environ["HEXA_TOKEN"] except KeyError: - raise WorkspaceConfigError("Workspace's token is not available in this environment.") + raise WorkspaceConfigError("The workspace token is not available in this environment.") @property - def slug(self): + def slug(self) -> str: + """The unique slug of the workspace. + + Slugs are used to identify the workspace. + """ try: return os.environ["HEXA_WORKSPACE"] except KeyError: - raise WorkspaceConfigError("Workspace's slug is not available in this environment.") + raise WorkspaceConfigError("The workspace slug is not available in this environment.") @property - def database_host(self): + def database_host(self) -> str: + """The workspace database host.""" try: return os.environ["WORKSPACE_DATABASE_HOST"] except KeyError: @@ -49,7 +66,8 @@ def database_host(self): ) @property - def database_username(self): + def database_username(self) -> str: + """The workspace database username.""" try: return os.environ["WORKSPACE_DATABASE_USERNAME"] except KeyError: @@ -60,6 +78,7 @@ def database_username(self): @property def database_password(self): + """The workspace database password.""" try: return os.environ.get("WORKSPACE_DATABASE_PASSWORD") except KeyError: @@ -70,6 +89,7 @@ def database_password(self): @property def database_name(self): + """The workspace database name.""" try: return os.environ["WORKSPACE_DATABASE_DB_NAME"] except KeyError: @@ -80,6 +100,7 @@ def database_name(self): @property def database_port(self): + """The workspace database port.""" try: return int(os.environ["WORKSPACE_DATABASE_PORT"]) except KeyError: @@ -90,6 +111,11 @@ def database_port(self): @property def database_url(self): + """The workspace database URL. + + The URL follows the official PostgreSQL specification. + (See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for more information) + """ return ( f"postgresql://{self.database_username}:{self.database_password}" f"@{self.database_host}:{self.database_port}/{self.database_name}" @@ -97,21 +123,41 @@ def database_url(self): @property def files_path(self) -> str: - """Return the base path to the filesystem, without trailing slash""" + """The base path to the filesystem, without trailing slash. + Examples + -------- + >>> f"{workspace.files_path}/some/path" + /home/hexa/workspace/some/path + """ # FIXME: This is a hack to make the SDK work in the context of the `python pipeline.py` command. 
        # We can remove this once we deprecate this way of running pipelines
        return os.environ["WORKSPACE_FILES_PATH"] if "WORKSPACE_FILES_PATH" in os.environ else "/home/hexa/workspace"
 
     @property
     def tmp_path(self) -> str:
-        """Return the base path to the tmp directory, without trailing slash"""
+        """The base path to the tmp directory, without trailing slash.
+
+        Examples
+        --------
+        >>> f"{workspace.tmp_path}/some/path"
+        '/home/hexa/tmp/some/path'
+        """
        # FIXME: This is a hack to make the SDK work in the context of the `python pipeline.py` command.
        # We can remove this once we deprecate this way of running pipelines
        return os.environ["WORKSPACE_TMP_PATH"] if "WORKSPACE_TMP_PATH" in os.environ else "/home/hexa/tmp"
 
-    def dhis2_connection(self, identifier: str = None, slug: str = None) -> DHIS2Connection:
+    @staticmethod
+    def dhis2_connection(identifier: str = None, slug: str = None) -> DHIS2Connection:
+        """Get a DHIS2 connection by its identifier.
+
+        Parameters
+        ----------
+        identifier : str
+            The identifier of the connection in the OpenHEXA backend
+        slug : str
+            Deprecated, same as identifier
+        """
        identifier = identifier or slug
        if slug is not None:
            warn("'slug' is deprecated. Use 'identifier' instead.", DeprecationWarning, stacklevel=2)
@@ -125,7 +171,17 @@ def dhis2_connection(self, identifier: str = None, slug: str = None) -> DHIS2Con
 
        return DHIS2Connection(url=url, username=username, password=password)
 
-    def postgresql_connection(self, identifier: str = None, slug: str = None) -> PostgreSQLConnection:
+    @staticmethod
+    def postgresql_connection(identifier: str = None, slug: str = None) -> PostgreSQLConnection:
+        """Get a PostgreSQL connection by its identifier.
+
+        Parameters
+        ----------
+        identifier : str
+            The identifier of the connection in the OpenHEXA backend
+        slug : str
+            Deprecated, same as identifier
+        """
        identifier = identifier or slug
        if slug is not None:
            warn("'slug' is deprecated. Use 'identifier' instead.", DeprecationWarning, stacklevel=2)
@@ -147,7 +203,17 @@ def postgresql_connection(self, identifier: str = None, slug: str = None) -> Pos
            database_name=dbname,
        )
 
-    def s3_connection(self, identifier: str = None, slug: str = None) -> S3Connection:
+    @staticmethod
+    def s3_connection(identifier: str = None, slug: str = None) -> S3Connection:
+        """Get an AWS S3 connection by its identifier.
+
+        Parameters
+        ----------
+        identifier : str
+            The identifier of the connection in the OpenHEXA backend
+        slug : str
+            Deprecated, same as identifier
+        """
        identifier = identifier or slug
        if slug is not None:
            warn("'slug' is deprecated. Use 'identifier' instead.", DeprecationWarning, stacklevel=2)
@@ -165,7 +231,17 @@ def s3_connection(self, identifier: str = None, slug: str = None) -> S3Connectio
            bucket_name=bucket_name,
        )
 
-    def gcs_connection(self, identifier: str = None, slug: str = None) -> GCSConnection:
+    @staticmethod
+    def gcs_connection(identifier: str = None, slug: str = None) -> GCSConnection:
+        """Get a Google Cloud Storage connection by its identifier.
+
+        Parameters
+        ----------
+        identifier : str
+            The identifier of the connection in the OpenHEXA backend
+        slug : str
+            Deprecated, same as identifier
+        """
        identifier = identifier or slug
        if slug is not None:
            warn("'slug' is deprecated. Use 'identifier' instead.", DeprecationWarning, stacklevel=2)
@@ -181,7 +257,17 @@ def gcs_connection(self, identifier: str = None, slug: str = None) -> GCSConnect
            bucket_name=bucket_name,
        )
 
-    def iaso_connection(self, identifier: str = None, slug: str = None) -> IASOConnection:
+    @staticmethod
+    def iaso_connection(identifier: str = None, slug: str = None) -> IASOConnection:
+        """Get an IASO connection by its identifier.
+
+        Parameters
+        ----------
+        identifier : str
+            The identifier of the connection in the OpenHEXA backend
+        slug : str
+            Deprecated, same as identifier
+        """
        identifier = identifier or slug
        if slug is not None:
            warn("'slug' is deprecated. Use 'identifier' instead.", DeprecationWarning, stacklevel=2)
@@ -195,30 +281,48 @@ def iaso_connection(self, identifier: str = None, slug: str = None) -> IASOConne
 
        return IASOConnection(url=url, username=username, password=password)
 
-    def custom_connection(self, identifier: str = None, slug: str = None):
+    @staticmethod
+    def custom_connection(identifier: str = None, slug: str = None) -> CustomConnection:
+        """Get a custom connection by its identifier.
+
+        Parameters
+        ----------
+        identifier : str
+            The identifier of the connection in the OpenHEXA backend
+        slug : str
+            Deprecated, same as identifier
+        """
        identifier = identifier or slug
        if slug is not None:
            warn("'slug' is deprecated. Use 'identifier' instead.", DeprecationWarning, stacklevel=2)
 
-        identifier = identifier.lower()
-        env_variable_prefix = stringcase.constcase(identifier)
+        env_variable_prefix = stringcase.constcase(identifier.lower())
        fields = {}
        for key, value in os.environ.items():
            if key.startswith(env_variable_prefix):
                field_key = key[len(f"{env_variable_prefix}_") :].lower()
                fields[field_key] = value
 
-        dataclass = make_dataclass(identifier, fields.keys())
+        if len(fields) == 0:
+            raise ConnectionDoesNotExist(f'No custom connection for "{identifier}"')
 
-        class CustomConnection(dataclass):
-            def __repr__(self):
-                return f"CustomConnection(name='{identifier}')"
+        dataclass = make_dataclass(
+            stringcase.pascalcase(identifier), fields.keys(), bases=(CustomConnection,), repr=False
+        )
 
-        return CustomConnection(**fields)
+        return dataclass(**fields)
 
     def create_dataset(self, identifier: str, name: str, description: str):
+        """Create a new dataset."""
        raise NotImplementedError("create_dataset is not implemented yet.")
 
-    def get_dataset(self, identifier: str):
+    def get_dataset(self, identifier: str) -> Dataset:
+        """Get a dataset by its identifier.
+
+        Parameters
+        ----------
+        identifier : str
+            The identifier of the dataset in the OpenHEXA backend
+        """
        response = graphql(
            """
            query getDataset($datasetSlug: String!, $workspaceSlug: String!)
            {
diff --git a/pyproject.toml b/pyproject.toml
index b5cbb60..c11a9ea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
 [project]
 name= "openhexa.sdk"
 version = "0.1.33"
-description = "OpenHexa SDK"
+description = "OpenHEXA SDK"
 
 authors = [
    { name = "Bluesquare", email = "dev@bluesquarehub.com"}
@@ -35,7 +35,7 @@ openhexa = "openhexa.cli:app"
 
 [project.optional-dependencies]
-dev = [ "ruff~=0.0.278", "pytest~=7.3.0", "build>=0.10,<1.1", "pytest-cov>=4.0,<4.2" , "black~=23.7.0", "pre-commit"]
+dev = [ "ruff~=0.1.8", "pytest~=7.3.0", "build>=0.10,<1.1", "pytest-cov>=4.0,<4.2", "pre-commit"]
 examples= ["geopandas~=0.12.2", "pandas~=2.0.0", "rasterio~=1.3.6", "rasterstats>=0.18,<0.20", "setuptools>=67.6.1,<69.1.0", "SQLAlchemy~=2.0.9", "psycopg2"]
@@ -53,26 +53,29 @@ namespaces = true
 [tool.setuptools]
 include-package-data = true
 
-[tool.black]
-line-length = 120
-exclude = '''
-(
-  /(
-      \.eggs  # exclude a few common directories in the
-    | \.git   # root of the project
-  )/
-  | node_modules
-)
-'''
-
 [tool.ruff]
 line-length = 120
 ignore = ["E501"]
 
+[tool.ruff.lint]
+extend-select = [
+    "D",   # pydocstyle
+    "I",   # isort
+    "UP",  # pyupgrade
+    "ANN", # flake8-annotations
+]
+# Disable all "missing type annotations" rules for now
+# TODO: enable
+ignore = ["ANN001", "ANN002", "ANN003", "ANN101", "ANN102", "ANN201", "ANN202", "ANN204", "ANN205", "ANN401"]
+
 [tool.ruff.pycodestyle]
 max-doc-length = 120
 
+[tool.ruff.lint.pydocstyle]
+convention = "numpy"  # Accepts: "google", "numpy", or "pep257".
+
+
 [tool.coverage.report]
 exclude_lines = [
    # Have to re-enable the standard pragma
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29..38bb211 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+"""Test package."""
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..e86c6c3
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,26 @@
+"""Module-level test fixtures."""
+from importlib import reload
+
+import pytest
+
+import openhexa.sdk
+
+
+@pytest.fixture(scope="function")
+def workspace():
+    """Build workspace fixture."""
+    from openhexa.sdk import workspace as global_workspace
+
+    reload(openhexa.sdk)
+
+    return global_workspace
+
+
+@pytest.fixture(scope="function")
+def current_run():
+    """Build current run fixture."""
+    from openhexa.sdk import current_run as global_current_run
+
+    reload(openhexa.sdk)
+
+    return global_current_run
diff --git a/tests/test_api.py b/tests/test_api.py
index a7913c6..d3c2335 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -1,26 +1,29 @@
+"""API interactions test module."""
+
 import configparser
 import os
-import pytest
 import shutil
-import yaml
 import uuid
-
-from click.testing import CliRunner
-from openhexa.cli.api import upload_pipeline
-from openhexa.cli.cli import pipelines_init, pipelines_delete
 from pathlib import Path
-
 from unittest import mock
 from zipfile import ZipFile
 
+import pytest
+import yaml
+from click.testing import CliRunner
+
+from openhexa.cli.api import upload_pipeline
+from openhexa.cli.cli import pipelines_delete, pipelines_init
+
 
 def test_upload_pipeline():
+    """Test upload API call."""
    # to enable zip file creation
    config = configparser.ConfigParser()
    config["openhexa"] = {"debug": True, "current_workspace": "test_workspace"}
 
    runner = CliRunner()
-    runner.invoke(pipelines_init, ["test_pipelines"])
+    runner.invoke(pipelines_init, ["test_pipelines"])  # NOQA
 
    pipeline_dir = Path.cwd() / 
"test_pipelines" pipeline_zip_file_dir = Path.cwd() / "pipeline.zip" @@ -41,17 +44,18 @@ def test_upload_pipeline(): def test_upload_pipeline_custom_files_path(): + """Test upload API call (custom file path).""" # to enable zip file creation config = configparser.ConfigParser() config["openhexa"] = {"debug": True, "current_workspace": "test_workspace"} runner = CliRunner() - runner.invoke(pipelines_init, ["test_pipelines"]) + runner.invoke(pipelines_init, ["test_pipelines"]) # NOQA pipeline_dir = Path.cwd() / "test_pipelines" pipeline_zip_file_dir = Path.cwd() / "pipeline.zip" (pipeline_dir / Path("data")).mkdir() - # setup a custom path for files location in workspace.yaml + # set up a custom path for files location in workspace.yaml pipeline_configs = {"files": {"path": "./data"}} with open(pipeline_dir / Path("workspace.yaml"), "w") as pipeline_configs_file: @@ -74,6 +78,7 @@ def test_upload_pipeline_custom_files_path(): def test_delete_pipeline_not_in_workspace(): + """Test delete pipeline (pipeline does not exist).""" config = configparser.ConfigParser() config["openhexa"] = {"debug": True, "current_workspace": "test_workspace"} @@ -83,12 +88,13 @@ def test_delete_pipeline_not_in_workspace(): runner = CliRunner() mocked_config.return_value = config mocked_graphql_client.return_value = {"pipelineByCode": None} - r = runner.invoke(pipelines_delete, ["test_pipelines"], input="test_pipelines") + r = runner.invoke(pipelines_delete, ["test_pipelines"], input="test_pipelines") # NOQA assert r.output == "Pipeline test_pipelines does not exist in workspace test_workspace\n" def test_delete_pipeline_confirm_code_invalid(): + """Test delete pipeline with an invalid confirmation code.""" config = configparser.ConfigParser() config["openhexa"] = {"debug": True, "current_workspace": "test_workspace"} @@ -105,12 +111,13 @@ def test_delete_pipeline_confirm_code_invalid(): "currentVersion": {"number": 1}, } } - r = runner.invoke(pipelines_delete, ["test_pipelines"], input="test_pipeline") + r = runner.invoke(pipelines_delete, ["test_pipelines"], input="test_pipeline") # NOQA # "Pipeline code and confirmation are different assert r.exit_code == 1 def test_delete_pipeline(): + """Happy path for delete pipeline API call.""" config = configparser.ConfigParser() config["openhexa"] = {"debug": True, "current_workspace": "test_workspace"} @@ -126,6 +133,6 @@ def test_delete_pipeline(): "currentVersion": {"number": 1}, } mocked_delete_pipeline.return_value = True - r = runner.invoke(pipelines_delete, ["test_pipelines"], input="test_pipelines") + r = runner.invoke(pipelines_delete, ["test_pipelines"], input="test_pipelines") # NOQA assert r.exit_code == 0 diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 4878b01..f913fff 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,3 +1,5 @@ +"""Dataset test module.""" + import os from unittest import TestCase from unittest.mock import patch @@ -6,12 +8,15 @@ class DatasetTest(TestCase): + """Dataset test class.""" + @patch.dict( os.environ, {"HEXA_WORKSPACE": "workspace-slug", "HEXA_TOKEN": "token", "HEXA_SERVER_URL": "server"}, ) @patch("openhexa.sdk.datasets.dataset.graphql") def test_create_dataset_version(self, mock_graphql): + """Ensure that dataset versions can be created.""" d = Dataset( id="id", slug="my-dataset", diff --git a/tests/test_environment.py b/tests/test_environment.py index 991f752..ef00060 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -1,3 +1,5 @@ +"""Collection of tests related to 
environment variables handling."""
+
 import os
 from unittest.mock import Mock, patch
 
@@ -5,14 +7,14 @@
 
 
 @patch.dict(os.environ, {"WORKSPACE_FILES_PATH": "/workspace/path"})
-def test_workspace_local_files_path():
-    from openhexa.sdk.workspaces import workspace
-
+def test_workspace_local_files_path(workspace):
+    """Test workspace files_path property with explicit environment variable."""
    assert workspace.files_path == "/workspace/path"
 
 
 @patch.dict(os.environ, {"HEXA_ENVIRONMENT": "LOCAL_PIPELINE"})
-def test_workspace_pipeline_files_path():
+def test_workspace_pipeline_files_path(workspace):
+    """Test workspace files_path property in local mode."""
    from openhexa.sdk.workspaces import workspace
 
    assert workspace.files_path == "/home/hexa/workspace"
@@ -22,21 +24,19 @@
    os.environ,
    {"HEXA_ENVIRONMENT": "CLOUD_PIPELINE", "HEXA_SERVER_URL": "https://test.openhexa.org"},
 )
-def test_connected():
-    from openhexa.sdk.pipelines import current_run
-
+def test_connected(current_run):
+    """Test run _connected property in CLOUD_PIPELINE mode."""
    assert current_run._connected is True
 
 
 @patch.dict(
    os.environ,
    {
-        "HEXA_ENVIRONMENT": "LOCAL_PIPELINE",
+        "HEXA_ENVIRONMENT": "CLOUD_PIPELINE",
    },
 )
-def test_not_connected():
-    from openhexa.sdk.pipelines import current_run
-
+def test_not_connected_missing_url(current_run):
+    """Test run _connected property in CLOUD_PIPELINE mode (missing HEXA_SERVER_URL)."""
    assert current_run._connected is False
 
 
@@ -46,9 +46,8 @@
        "HEXA_ENVIRONMENT": "LOCAL_PIPELINE",
    },
 )
-def test_not_connected_missing_url():
-    from openhexa.sdk.pipelines import current_run
-
+def test_not_connected(current_run):
+    """Test run _connected property in LOCAL_PIPELINE mode."""
    assert current_run._connected is False
 
 
@@ -59,6 +58,7 @@
    },
 )
 def test_pipeline_local():
+    """Test pipeline _connected property in LOCAL_PIPELINE mode."""
    pipeline_func = Mock()
 
    pipeline = Pipeline("code", "pipeline", pipeline_func, [])
@@ -73,6 +73,7 @@
    },
 )
 def test_pipeline_pipeline():
+    """Test pipeline _connected property in CLOUD_PIPELINE mode."""
    pipeline_func = Mock()
 
    pipeline = Pipeline("code", "pipeline", pipeline_func, [])
diff --git a/tests/test_parameter.py b/tests/test_parameter.py
index 4f6c99f..94f05fd 100644
--- a/tests/test_parameter.py
+++ b/tests/test_parameter.py
@@ -1,36 +1,38 @@
-import pytest
+"""Parameter test module."""
+
 import os
+from unittest import mock
+
+import pytest
 import stringcase
 
 from openhexa.sdk import (
    DHIS2Connection,
+    GCSConnection,
    IASOConnection,
    PostgreSQLConnection,
-    GCSConnection,
    S3Connection,
 )
-
 from openhexa.sdk.pipelines.parameter import (
    Boolean,
+    DHIS2ConnectionType,
    Float,
    FunctionWithParameter,
+    GCSConnectionType,
+    IASOConnectionType,
    Integer,
    InvalidParameterError,
    Parameter,
    ParameterValueError,
-    StringType,
    PostgreSQLConnectionType,
-    GCSConnectionType,
    S3ConnectionType,
-    IASOConnectionType,
-    DHIS2ConnectionType,
+    StringType,
    parameter,
 )
-from unittest import mock
-
 
 def test_parameter_types_normalize():
+    """Check normalization of basic types."""
    # StringType
    string_parameter_type = StringType()
    assert string_parameter_type.normalize("a string") == "a string"
@@ -56,6 +58,7 @@
 
 def test_parameter_types_validate():
+    """Sanity checks for basic types validation."""
    # StringType
    string_parameter_type = StringType()
    assert string_parameter_type.validate("a string") == "a string"
@@ -83,6
+86,7 @@ def test_parameter_types_validate(): def test_validate_postgres_connection(): + """Check PostgreSQL connection validation.""" identifier = "polio-ff3a0d" env_variable_prefix = stringcase.constcase(identifier) host = "https://172.17.0.1" @@ -109,6 +113,7 @@ def test_validate_postgres_connection(): def test_validate_dhis2_connection(): + """Check DHIS2 connection validation.""" identifier = "dhis2-connection-id" env_variable_prefix = stringcase.constcase(identifier) url = "https://test.dhis2.org/" @@ -123,13 +128,14 @@ def test_validate_dhis2_connection(): f"{env_variable_prefix}_PASSWORD": password, }, ): - dhsi2_parameter_type = DHIS2ConnectionType() - assert dhsi2_parameter_type.validate(identifier) == DHIS2Connection(url, username, password) + dhis2_parameter_type = DHIS2ConnectionType() + assert dhis2_parameter_type.validate(identifier) == DHIS2Connection(url, username, password) with pytest.raises(ParameterValueError): - dhsi2_parameter_type.validate(86) + dhis2_parameter_type.validate(86) def test_validate_iaso_connection(): + """Check IASO connection validation.""" identifier = "iaso-connection-id" env_variable_prefix = stringcase.constcase(identifier) url = "https://test.iaso.org/" @@ -151,6 +157,7 @@ def test_validate_iaso_connection(): def test_validate_gcs_connection(): + """Check GCS connection validation.""" identifier = "gcs-connection-id" env_variable_prefix = stringcase.constcase(identifier) service_account_key = "HqQBxH0BAI3zF7kANUNlGg" @@ -170,6 +177,7 @@ def test_validate_gcs_connection(): def test_validate_s3_connection(): + """Check S3 connection validation.""" identifier = "s3-connection-id" env_variable_prefix = stringcase.constcase(identifier) secret_access_key = "HqQBxH0BAI3zF7kANUNlGg" @@ -191,9 +199,10 @@ def test_validate_s3_connection(): def test_parameter_init(): + """Sanity checks for parameter initialization.""" # Wrong type with pytest.raises(InvalidParameterError): - Parameter("arg", type="string") + Parameter("arg", type="string") # NOQA # Wrong code with pytest.raises(InvalidParameterError): @@ -233,6 +242,7 @@ def test_parameter_init(): def test_parameter_validate_single(): + """Base check for single-value validation.""" # required is True by default parameter_1 = Parameter("arg1", type=str) assert parameter_1.validate("a valid string") == "a valid string" @@ -257,6 +267,7 @@ def test_parameter_validate_single(): def test_parameter_validate_multiple(): + """Test multiple values validation rules.""" # required is True by default parameter_1 = Parameter("arg1", type=str, multiple=True) assert parameter_1.validate(["Valid string", "Another valid string"]) == [ @@ -290,6 +301,7 @@ def test_parameter_validate_multiple(): def test_parameter_parameters_spec(): + """Verify that parameter specifications are built properly and have the proper defaults.""" # required is True by default an_parameter = Parameter("arg1", type=str, default="yep") another_parameter = Parameter( @@ -326,6 +338,8 @@ def test_parameter_parameters_spec(): def test_parameter_decorator(): + """Ensure that the @parameter decorator behaves as expected (options and defaults).""" + @parameter("arg1", type=int) @parameter( "arg2", diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index b82aeb7..2dd16dc 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,14 +1,16 @@ +"""Pipeline test module.""" + +import os from unittest.mock import Mock, patch import pytest import stringcase -import os from openhexa.sdk import ( DHIS2Connection, + GCSConnection, 
    IASOConnection,
    PostgreSQLConnection,
-    GCSConnection,
    S3Connection,
 )
 from openhexa.sdk.pipelines.parameter import Parameter, ParameterValueError
@@ -16,6 +18,7 @@
 
 
 def test_pipeline_run_valid_config():
+    """Happy path for pipeline run config."""
    pipeline_func = Mock()
    parameter_1 = Parameter("arg1", type=str)
    parameter_2 = Parameter("arg2", type=str, multiple=True)
@@ -28,6 +31,7 @@
 
 
 def test_pipeline_run_invalid_config():
+    """Verify that invalid configuration values raise an exception."""
    pipeline_func = Mock()
    parameter_1 = Parameter("arg1", type=str)
    pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1])
@@ -36,6 +40,7 @@
 
 
 def test_pipeline_run_extra_config():
+    """Verify that extra (unexpected) configuration values raise an exception."""
    pipeline_func = Mock()
    parameter_1 = Parameter("arg1", type=str)
    pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1])
@@ -44,6 +49,7 @@
 
 
 def test_pipeline_run_connection_dhis2_parameter_config():
+    """Ensure that DHIS2 connection parameter values are built properly."""
    identifier = "dhis2-connection-id"
    env_variable_prefix = stringcase.constcase(identifier)
    url = "https://test.dhis2.org/"
@@ -69,6 +75,7 @@
 
 
 def test_pipeline_run_connection_iaso_parameter_config():
+    """Ensure that IASO connection parameter values are built properly."""
    identifier = "iaso-connection-id"
    env_variable_prefix = stringcase.constcase(identifier)
    url = "https://test.iaso.org/"
@@ -93,6 +100,7 @@
 
 def test_pipeline_run_connection_gcs_parameter_config():
+    """Ensure that GCS connection parameter values are built properly."""
    identifier = "gcs-connection-id"
    env_variable_prefix = stringcase.constcase(identifier)
    service_account_key = "HqQBxH0BAI3zF7kANUNlGg"
@@ -115,6 +123,7 @@
 
 def test_pipeline_run_connection_s3_parameter_config():
+    """Ensure that S3 connection parameter values are built properly."""
    identifier = "s3-connection-id"
    env_variable_prefix = stringcase.constcase(identifier)
    secret_access_key = "HqQBxH0BAI3zF7kANUNlGg"
@@ -141,6 +150,7 @@
 
 def test_pipeline_run_connection_postgres_parameter_config():
+    """Ensure that PostgreSQL connection parameter values are built properly."""
    identifier = "postgres-connection-id"
    env_variable_prefix = stringcase.constcase(identifier)
    host = "https://127.0.0.1"
@@ -173,6 +183,7 @@
 
 def test_pipeline_parameters_spec():
+    """Base checks for parameter specs building."""
    pipeline_func = Mock()
    parameter_1 = Parameter("arg1", type=str)
    parameter_2 = Parameter("arg2", type=str, multiple=True)
diff --git a/tests/test_workspace.py b/tests/test_workspace.py
index 154fcbd..41cd7b5 100644
--- a/tests/test_workspace.py
+++ b/tests/test_workspace.py
@@ -1,22 +1,26 @@
+"""Workspace test module."""
+
 import os
+import re
 from tempfile import mkdtemp
+from unittest import mock
 
 import pytest
-import re
 import stringcase
 
-from openhexa.sdk.workspaces.workspace import ConnectionDoesNotExist, workspace
-from unittest import mock
+from openhexa.sdk.workspaces.workspace import ConnectionDoesNotExist
 
-def test_workspace_files_path(monkeypatch):
+def test_workspace_files_path(monkeypatch, workspace):
+    """Basic checks for the
Workspace.files_path() method.""" assert workspace.files_path == "/home/hexa/workspace" monkeypatch.setenv("WORKSPACE_FILES_PATH", "/Users/John/openhexa/project-1/workspace") assert workspace.files_path == "/Users/John/openhexa/project-1/workspace" -def test_workspace_tmp_path(monkeypatch): +def test_workspace_tmp_path(monkeypatch, workspace): + """Basic checks for the Workspace.tmp_path() method.""" assert workspace.tmp_path == "/home/hexa/tmp" mock_tmp_path = mkdtemp() @@ -24,13 +28,15 @@ def test_workspace_tmp_path(monkeypatch): assert workspace.tmp_path == mock_tmp_path -def test_workspace_dhis2_connection_not_exist(): +def test_workspace_dhis2_connection_not_exist(workspace): + """Does not exist test case for DHIS2 connections.""" identifier = "polio-ff3a0d" with pytest.raises(ConnectionDoesNotExist): workspace.dhis2_connection(identifier=identifier) -def test_workspace_dhis2_connection(): +def test_workspace_dhis2_connection(workspace): + """Base test case for DHIS2 connections.""" identifier = "polio-ff3a0d" env_variable_prefix = stringcase.constcase(identifier) url = "https://test.dhis2.org/" @@ -48,17 +54,19 @@ def test_workspace_dhis2_connection(): assert dhis2_connection.url == url assert dhis2_connection.username == username assert dhis2_connection.password == password - assert re.search("password", repr(dhis2_connection)) is None - assert re.search("password", str(dhis2_connection)) is None + assert re.search("password", repr(dhis2_connection), re.IGNORECASE) is None + assert re.search("password", str(dhis2_connection), re.IGNORECASE) is None -def test_workspace_postgresql_connection_not_exist(): +def test_workspace_postgresql_connection_not_exist(workspace): + """Does not exist test case for PostgreSQL connections.""" identifier = "polio-ff3a0d" with pytest.raises(ConnectionDoesNotExist): workspace.postgresql_connection(identifier=identifier) -def test_workspace_postgresql_connection(): +def test_workspace_postgresql_connection(workspace): + """Base test case for PostgreSQL connections.""" identifier = "polio-ff3a0d" env_variable_prefix = stringcase.constcase(identifier) host = "https://172.17.0.1" @@ -84,17 +92,19 @@ def test_workspace_postgresql_connection(): assert postgres_connection.port == int(port) assert postgres_connection.database_name == database_name assert postgres_connection.url == url - assert re.search("password", repr(postgres_connection)) is None - assert re.search("password", str(postgres_connection)) is None + assert re.search("password", repr(postgres_connection), re.IGNORECASE) is None + assert re.search("password", str(postgres_connection), re.IGNORECASE) is None -def test_workspace_S3_connection_not_exist(): +def test_workspace_S3_connection_not_exist(workspace): + """Does not exist test case for S3 connections.""" identifier = "polio-ff3a0d" with pytest.raises(ConnectionDoesNotExist): workspace.s3_connection(identifier=identifier) -def test_workspace_s3_connection(): +def test_workspace_s3_connection(workspace): + """Base test case for S3 connections.""" identifier = "polio-ff3a0d" env_variable_prefix = stringcase.constcase(identifier) secret_access_key = "HqQBxH0BAI3zF7kANUNlGg" @@ -113,19 +123,21 @@ def test_workspace_s3_connection(): assert s3_connection.secret_access_key == secret_access_key assert s3_connection.access_key_id == access_key_id assert s3_connection.bucket_name == bucket_name - assert re.search("secret_access_key", repr(s3_connection)) is None - assert re.search("secret_access_key", str(s3_connection)) is None - assert 
re.search("access_key_id", repr(s3_connection)) is None - assert re.search("access_key_id", str(s3_connection)) is None + assert re.search("secret_access_key", repr(s3_connection), re.IGNORECASE) is None + assert re.search("secret_access_key", str(s3_connection), re.IGNORECASE) is None + assert re.search("access_key_id", repr(s3_connection), re.IGNORECASE) is None + assert re.search("access_key_id", str(s3_connection), re.IGNORECASE) is None -def test_workspace_gcs_connection_not_exist(): +def test_workspace_gcs_connection_not_exist(workspace): + """Does not exist test case for GCS connections.""" identifier = "polio-ff3a0d" with pytest.raises(ConnectionDoesNotExist): workspace.gcs_connection(identifier=identifier) -def test_workspace_gcs_connection(): +def test_workspace_gcs_connection(workspace): + """Base test case for GCS connections.""" identifier = "polio-ff3a0d" env_variable_prefix = stringcase.constcase(identifier) service_account_key = "HqQBxH0BAI3zF7kANUNlGg" @@ -141,17 +153,19 @@ def test_workspace_gcs_connection(): s3_connection = workspace.gcs_connection(identifier=identifier) assert s3_connection.service_account_key == service_account_key assert s3_connection.bucket_name == bucket_name - assert re.search("service_account_key", repr(s3_connection)) is None - assert re.search("service_account_key", str(s3_connection)) is None + assert re.search("service_account_key", repr(s3_connection), re.IGNORECASE) is None + assert re.search("service_account_key", str(s3_connection), re.IGNORECASE) is None -def test_workspace_iaso_connection_not_exist(): +def test_workspace_iaso_connection_not_exist(workspace): + """Does not exist test case for IASO connections.""" identifier = "iaso-account" with pytest.raises(ConnectionDoesNotExist): workspace.iaso_connection(identifier=identifier) -def test_workspace_iaso_connection(): +def test_workspace_iaso_connection(workspace): + """Base test case for IASO connections.""" identifier = "iaso-account" env_variable_prefix = stringcase.constcase(identifier) url = "https://test.iaso.org/" @@ -170,11 +184,19 @@ def test_workspace_iaso_connection(): assert iaso_connection.url == url assert iaso_connection.username == username assert iaso_connection.password == password - assert re.search("password", repr(iaso_connection)) is None - assert re.search("password", str(iaso_connection)) is None + assert re.search("password", repr(iaso_connection), re.IGNORECASE) is None + assert re.search("password", str(iaso_connection), re.IGNORECASE) is None + + +def test_workspace_custom_connection_not_exist(workspace): + """Does not exist test case for custom connections.""" + identifier = "custom-stuff" + with pytest.raises(ConnectionDoesNotExist): + workspace.custom_connection(identifier=identifier) -def test_workspace_custom_connection(): +def test_workspace_custom_connection(workspace): + """Base test case for custom connections.""" identifier = "my_connection" env_variable_prefix = stringcase.constcase(identifier) username = "kaggle_username" @@ -190,11 +212,12 @@ def test_workspace_custom_connection(): custom_connection = workspace.custom_connection(identifier=identifier) assert custom_connection.username == username assert custom_connection.password == password - assert re.search(f"{env_variable_prefix}_PASSWORD", repr(custom_connection)) is None - assert re.search(f"{env_variable_prefix}_PASSWORD", str(custom_connection)) is None + assert re.search("username", repr(custom_connection), re.IGNORECASE) is None + assert re.search("password", str(custom_connection), 
re.IGNORECASE) is None -def test_connection_by_slug_warning(): +def test_connection_by_slug_warning(workspace): + """Ensure that using the slug keyword argument when getting a connection generates a deprecation warning.""" identifier = "polio-ff3a0d" env_variable_prefix = stringcase.constcase(identifier) url = "https://test.dhis2.org/" @@ -214,7 +237,8 @@ def test_connection_by_slug_warning(): assert workspace.dhis2_connection(identifier=identifier).url == url -def test_connection_various_case(): +def test_connection_various_case(workspace): + """Ensure that identifiers used when getting connections are case-insensitive.""" env_variable_prefix = stringcase.constcase("polio-123") url = "https://test.dhis2.org/" username = "dhis2"