diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e5c0f3f..4fdeed2 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,11 +1,7 @@
repos:
- - repo: https://github.com/astral-sh/ruff-pre-commit
- # Ruff version.
- rev: v0.0.280
- hooks:
- - id: ruff
- - repo: https://github.com/psf/black
- rev: 23.7.0
- hooks:
- - id: black
- language_version: python3
+- repo: https://github.com/astral-sh/ruff-pre-commit
+ rev: v0.1.8
+ hooks:
+ - id: ruff
+ args: [ --fix, --exit-non-zero-on-fix ]
+ - id: ruff-format
\ No newline at end of file
diff --git a/README.md b/README.md
index 21f0f70..9ece4b0 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,46 @@
-# OpenHexa Python SDK
+
+Open-source Data integration platform
+
-The OpenHexa Python SDK is a tool that helps you write code for the OpenHexa platform.
+OpenHEXA Python SDK
+===================
-It is particularly useful to write OpenHexa data pipelines, but can also be used in the OpenHexa notebooks environment.
+OpenHEXA is an open-source data integration platform developed by [Bluesquare](https://bluesquarehub.com).
-## Quickstart
+Its goal is to facilitate data integration and analysis workflows, in particular in the context of public health
+projects.
-### Writing and deploying pipelines
+Please refer to the [OpenHEXA wiki](https://github.com/BLSQ/openhexa/wiki/Home) for more information about OpenHEXA.
+
+This repository contains the code of the OpenHEXA SDK, a library that allows you to write code for the OpenHEXA platform.
+It is particularly useful to write OpenHEXA data pipelines, but can also be used in the OpenHEXA notebooks environment.
+
+The OpenHEXA wiki has a section dedicated to the SDK:
+[Using the OpenHEXA SDK](https://github.com/BLSQ/openhexa/wiki/Using-the-OpenHEXA-SDK).
+
+For more information about the technical aspects of OpenHEXA, you might be interested in the following two wiki pages:
+
+- [Installing OpenHEXA](https://github.com/BLSQ/openhexa/wiki/Installation-instructions)
+- [Technical Overview](https://github.com/BLSQ/openhexa/wiki/Technical-overview)
+
+Requirements
+------------
+
+The OpenHEXA SDK requires Python 3.9 or newer. It is not yet compatible with Python 3.12 or later.
+
+If you want to be able to run pipelines in a containerized environment on your machine, you will need
+[Docker](https://www.docker.com/).
+
+Quickstart
+----------
Here's a super minimal example to get you started. First, create a new directory and a virtual environment:
@@ -17,12 +51,15 @@ python -m venv venv
source venv/bin/activate
```
-You can then install the OpenHexa SDK:
+You can then install the OpenHEXA SDK:
```shell
pip install --upgrade openhexa.sdk
```
+💡 New OpenHEXA SDK versions are released on a regular basis. Don't forget to update your local installation with
+`pip install --upgrade` from time to time!
+
Now that the SDK is installed within your virtual environment, you can use the `openhexa` CLI utility to create
a new pipeline:
@@ -30,8 +67,8 @@ a new pipeline:
openhexa pipelines init "My awesome pipeline"
```
-Great! As you can see in the console output, the OpenHexa CLI has created a new directory, which contains the basic
-structure required for an OpenHexa pipeline. You can now `cd` in the new pipeline directory and run the pipeline:
+Great! As you can see in the console output, the OpenHEXA CLI has created a new directory, which contains the basic
+structure required for an OpenHEXA pipeline. You can now `cd` in the new pipeline directory and run the pipeline:
```shell
cd my_awesome_pipeline
@@ -41,11 +78,11 @@ python pipeline.py
Congratulations! You have successfully run your first pipeline locally.
If you inspect the actual pipeline code, you will see that it doesn't do a lot of things, but it is still a perfectly
-valid OpenHexa pipeline.
+valid OpenHEXA pipeline.
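+
+For reference, a newly generated pipeline looks roughly like this (a simplified sketch of the skeleton template used
+by `openhexa pipelines init`):
+
+```python
+from openhexa.sdk import current_run, pipeline
+
+
+@pipeline("my-awesome-pipeline", name="My awesome pipeline")
+def my_awesome_pipeline():
+    count = task_1()
+    task_2(count)
+
+
+@my_awesome_pipeline.task
+def task_1():
+    current_run.log_info("In task 1...")
+    return 42
+
+
+@my_awesome_pipeline.task
+def task_2(count):
+    current_run.log_info(f"In task 2... count is {count}")
+```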
-Let's publish to an actual OpenHexa workspace so that it can run online.
+Let's publish to an actual OpenHEXA workspace so that it can run online.
-Using the OpenHexa web interface, within a workspace, navigate to the Pipelines tab and click on "Create".
+Using the OpenHEXA web interface, within a workspace, navigate to the Pipelines tab and click on "Create".
Copy the command displayed in the popup in your terminal:
@@ -62,20 +99,17 @@ openhexa pipelines push
```
As it is the first time, the CLI will ask you to confirm the creation operation. After confirmation the console will
-output the link to the pipeline screen in the OpenHexa interface.
-
-You can now open the link and run the pipeline using the OpenHexa web interface.
+output the link to the pipeline screen in the OpenHEXA interface.
-### Using the SDK in the notebooks environment
+You can now open the link and run the pipeline using the OpenHEXA web interface.
-TBC
+Contributing
+------------
-## Contributing
+The following sections explain how you can set up a local development environment if you want to participate in the
+development of the SDK.
-The following sections explain how you can setup a local development environment if you want to participate to the
-development of the SDK
-
-### Development setup
+### SDK development setup
Install the SDK in editable mode:
@@ -85,7 +119,13 @@ source venv/bin/activate # Activate the venv
pip install -e ".[dev]" # Necessary to be able to run the openhexa CLI
```
-### Using a local installation of the OpenHexa backend to run pipelines
+### Using a local installation of OpenHEXA to run pipelines
+
+While it is possible to run pipelines locally using only the SDK, if you want to run OpenHEXA in a more realistic
+setting, you will need to install the OpenHEXA app and frontend components. Please refer to the
+[installation instructions](https://github.com/BLSQ/openhexa/wiki/Installation-instructions) for more information.
+
+You can then configure the OpenHEXA CLI to connect to your local backend:
```shell
openhexa config set_url http://localhost:8000
@@ -95,7 +135,7 @@ Notes: you can monitor the status of your pipelines using http://localhost:8000/
### Running the tests
-Run the tests using pytest:
+You can run the tests using pytest:
```shell
pytest
diff --git a/examples/pipelines/logistic_stats/pipeline.py b/examples/pipelines/logistic_stats/pipeline.py
index 513823f..0933a69 100644
--- a/examples/pipelines/logistic_stats/pipeline.py
+++ b/examples/pipelines/logistic_stats/pipeline.py
@@ -1,3 +1,5 @@
+"""Simple module for a sample logistic pipeline."""
+
import json
import typing
from io import BytesIO
@@ -26,6 +28,7 @@
)
@parameter("oul", name="Organisation unit level", type=int, default=2)
def logistic_stats(deg: str, periods: str, oul: int):
+ """Run a basic logistic stats pipeline."""
dhis2_data = dhis2_download(deg, periods, oul)
gadm_data = gadm_download()
worldpop_data = worldpop_download()
@@ -34,7 +37,8 @@ def logistic_stats(deg: str, periods: str, oul: int):
@logistic_stats.task
-def dhis2_download(data_element_group: str, periods: str, org_unit_level: int) -> typing.Dict[str, typing.Any]:
+def dhis2_download(data_element_group: str, periods: str, org_unit_level: int) -> dict[str, typing.Any]:
+ """Download DHIS2 data."""
connection = workspace.dhis2_connection("dhis2-play")
base_url = f"{connection.url}/api"
session = requests.Session()
@@ -61,6 +65,7 @@ def dhis2_download(data_element_group: str, periods: str, org_unit_level: int) -
@logistic_stats.task
def gadm_download():
+ """Download administrative boundaries data from UCDavis."""
url = "https://geodata.ucdavis.edu/gadm/gadm4.1/gpkg/gadm41_SLE.gpkg"
r = requests.get(url, timeout=30)
@@ -69,6 +74,7 @@ def gadm_download():
@logistic_stats.task
def worldpop_download():
+ """Download population data from worldpop.org."""
base_url = "https://data.worldpop.org/"
url = f"{base_url}GIS/Population/Global_2000_2020_Constrained/2020/maxar_v1/SLE/sle_ppp_2020_UNadj_constrained.tif"
r = requests.get(url)
@@ -77,7 +83,8 @@ def worldpop_download():
@logistic_stats.task
-def model(dhis2_data: typing.Dict[str, typing.Any], gadm_data, worldpop_data):
+def model(dhis2_data: dict[str, typing.Any], gadm_data, worldpop_data):
+ """Build a basic data model."""
# Load DHIS2 data
dhis2_df = pd.DataFrame(dhis2_data["rows"], columns=[h["column"] for h in dhis2_data["headers"]])
dhis2_df = dhis2_df.rename(columns={"Data": "Data element id", "Organisation unit": "Organisation unit id"})
diff --git a/examples/pipelines/logistic_stats/tests/__init__.py b/examples/pipelines/logistic_stats/tests/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/examples/pipelines/simple_io/pipeline.py b/examples/pipelines/simple_io/pipeline.py
index f847316..8a52760 100644
--- a/examples/pipelines/simple_io/pipeline.py
+++ b/examples/pipelines/simple_io/pipeline.py
@@ -1,3 +1,5 @@
+"""Simple module for a sample IO pipeline."""
+
import json
import pandas as pd
@@ -9,6 +11,7 @@
@pipeline("simple-io", name="Simple IO")
def simple_io():
+ """Run a simple IO pipeline."""
# Read and write from/to workspace files
raw_files_data = load_files_data()
transform_and_write_files_data(raw_files_data)
@@ -23,6 +26,7 @@ def simple_io():
@simple_io.task
def load_files_data():
+ """Load data from workspace filesystem."""
current_run.log_info("Loading files data...")
return pd.read_csv(f"{workspace.files_path}/raw.csv")
@@ -30,6 +34,7 @@ def load_files_data():
@simple_io.task
def transform_and_write_files_data(raw_data: pd.DataFrame):
+ """Simulate a transformation on the provided dataframe and write data to workspace filesystem."""
current_run.log_info("Transforming files data...")
transformed_data = raw_data.copy()
@@ -41,6 +46,7 @@ def transform_and_write_files_data(raw_data: pd.DataFrame):
@simple_io.task
def load_data_from_postgresql() -> pd.DataFrame:
+ """Perform a simple SELECT query in the workspace database."""
current_run.log_info("Loading Postgres data...")
engine = create_engine(workspace.database_url)
@@ -50,6 +56,7 @@ def load_data_from_postgresql() -> pd.DataFrame:
@simple_io.task
def transform_and_write_sql_data(raw_data: pd.DataFrame):
+ """Simulate a transform operation on the provided data and load it in the workspace database."""
current_run.log_info("Transforming postgres data...")
engine = create_engine(workspace.database_url)
@@ -61,6 +68,7 @@ def transform_and_write_sql_data(raw_data: pd.DataFrame):
@simple_io.task
def load_dhis2_data():
+ """Load data from DHIS2."""
current_run.log_info("Loading DHIS2 data...")
connection = workspace.dhis2_connection("dhis2-play")
diff --git a/openhexa/cli/__init__.py b/openhexa/cli/__init__.py
index 7272735..eeb05c4 100644
--- a/openhexa/cli/__init__.py
+++ b/openhexa/cli/__init__.py
@@ -1,3 +1,5 @@
+"""CLI package."""
+
from .cli import app
__all__ = ["app"]
diff --git a/openhexa/cli/api.py b/openhexa/cli/api.py
index e9b0618..4c46298 100644
--- a/openhexa/cli/api.py
+++ b/openhexa/cli/api.py
@@ -1,8 +1,11 @@
+"""Collection of functions that interacts with the OpenHEXA API."""
+
import base64
-import configparser
import enum
import io
import os
+import typing
+from configparser import ConfigParser
from importlib.metadata import version
from pathlib import Path
from zipfile import ZipFile
@@ -16,20 +19,29 @@
class InvalidDefinitionError(Exception):
+ """Raised whenever pipeline parameters and/or pipeline options are incompatible."""
+
pass
-class PipelineErrorEnum(enum.Enum):
+class PipelineDefinitionErrorCode(enum.Enum):
+ """Enumeration of possible pipeline definition error codes."""
+
PIPELINE_DOES_NOT_SUPPORT_PARAMETERS = "PIPELINE_DOES_NOT_SUPPORT_PARAMETERS"
INVALID_TIMEOUT_VALUE = "INVALID_TIMEOUT_VALUE"
-def is_debug(config: configparser.ConfigParser):
+def is_debug(config: ConfigParser) -> bool:
+    """Determine whether the provided configuration has the debug flag enabled."""
return config.getboolean("openhexa", "debug", fallback=False)
def open_config():
- config = configparser.ConfigParser()
+ """Open the local configuration file using configparser.
+
+ A default configuration file will be generated if the file does not exist.
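+
+    The file is expected to look something like this (all values are illustrative):
+
+        [openhexa]
+        url = https://app.openhexa.example.org
+        debug = False
+        current_workspace = my-workspace
+
+        [workspaces]
+        my-workspace = <access token>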
+ """
+ config = ConfigParser()
if os.path.exists(CONFIGFILE_PATH):
config.read(CONFIGFILE_PATH)
else:
@@ -44,12 +56,14 @@ def open_config():
return config
-def save_config(config):
+def save_config(config: ConfigParser):
+ """Save the provided configparser local configuration to disk."""
with open(CONFIGFILE_PATH, "w") as configfile:
config.write(configfile)
def graphql(config, query: str, variables=None, token=None):
+ """Perform a GraphQL request."""
url = config["openhexa"]["url"] + "/graphql/"
if token is None:
current_workspace = config["openhexa"]["current_workspace"]
@@ -89,6 +103,7 @@ def graphql(config, query: str, variables=None, token=None):
def get_workspace(config, slug: str, token: str):
+ """Get a single workspace."""
return graphql(
config,
"""
@@ -104,7 +119,8 @@ def get_workspace(config, slug: str, token: str):
)["workspace"]
-def get_pipelines(config):
+def list_pipelines(config):
+ """List all pipelines in the workspace."""
data = graphql(
config,
"""
@@ -126,7 +142,8 @@ def get_pipelines(config):
return data["pipelines"]["items"]
-def get_pipeline(config, pipeline_code: str):
+def get_pipeline(config, pipeline_code: str) -> dict[str, typing.Any]:
+ """Get a single pipeline."""
data = graphql(
config,
"""
@@ -149,6 +166,7 @@ def get_pipeline(config, pipeline_code: str):
def create_pipeline(config, pipeline_code: str, pipeline_name: str):
+ """Create a pipeline using the API."""
data = graphql(
config,
"""
@@ -179,7 +197,8 @@ def create_pipeline(config, pipeline_code: str, pipeline_name: str):
return data["createPipeline"]["pipeline"]
-def delete_pipeline(config, id: str):
+def delete_pipeline(config, pipeline_id: str):
+ """Delete a single pipeline."""
data = graphql(
config,
"""
@@ -190,7 +209,7 @@ def delete_pipeline(config, id: str):
}
}
""",
- {"input": {"id": id}},
+ {"input": {"id": pipeline_id}},
)
if not data["deletePipeline"]["success"]:
@@ -200,7 +219,7 @@ def delete_pipeline(config, id: str):
def ensure_is_pipeline_dir(pipeline_path: str):
- # Ensure that there is a pipeline.py file in the directory
+ """Ensure that there is a pipeline.py file in the directory."""
if not os.path.isdir(pipeline_path):
raise Exception(f"Path {pipeline_path} is not a directory")
if not os.path.exists(pipeline_path):
@@ -211,11 +230,15 @@ def ensure_is_pipeline_dir(pipeline_path: str):
return True
-def upload_pipeline(config, pipeline_directory_path: str):
+def upload_pipeline(config, pipeline_directory_path: typing.Union[str, Path]):
+ """Upload the pipeline contained in the provided directory using the GraphQL API.
+
+ The pipeline code will be zipped and base64-encoded before being sent to the backend.
+ """
pipeline = import_pipeline(pipeline_directory_path)
directory = Path(os.path.abspath(pipeline_directory_path))
- zipFile = io.BytesIO(b"")
+ zip_file = io.BytesIO(b"")
if is_debug(config):
click.echo("Generating ZIP file:")
@@ -228,7 +251,7 @@ def upload_pipeline(config, pipeline_directory_path: str):
if env_vars.get("WORKSPACE_FILES_PATH") and Path(env_vars["WORKSPACE_FILES_PATH"]) not in excluded_paths:
excluded_paths.append(Path(env_vars["WORKSPACE_FILES_PATH"]))
- with ZipFile(zipFile, "w") as zipObj:
+ with ZipFile(zip_file, "w") as zipObj:
for path in directory.glob("**/*"):
if path.name == "python":
# We are in a virtual environment
@@ -252,15 +275,15 @@ def upload_pipeline(config, pipeline_directory_path: str):
click.echo(f"\t{file_path.name}")
zipObj.write(file_path, file_path.relative_to(directory))
- zipFile.seek(0)
+ zip_file.seek(0)
if is_debug(config):
- # Write zipFile to disk for debugging
+ # Write zip_file to disk for debugging
with open("pipeline.zip", "wb") as debug_file:
- debug_file.write(zipFile.read())
- zipFile.seek(0)
+ debug_file.write(zip_file.read())
+ zip_file.seek(0)
- base64_content = base64.b64encode(zipFile.read()).decode("ascii")
+ base64_content = base64.b64encode(zip_file.read()).decode("ascii")
data = graphql(
config,
@@ -285,11 +308,12 @@ def upload_pipeline(config, pipeline_directory_path: str):
)
if not data["uploadPipeline"]["success"]:
- if PipelineErrorEnum.PIPELINE_DOES_NOT_SUPPORT_PARAMETERS.value in data["uploadPipeline"]["errors"]:
+ if PipelineDefinitionErrorCode.PIPELINE_DOES_NOT_SUPPORT_PARAMETERS.value in data["uploadPipeline"]["errors"]:
raise InvalidDefinitionError(
- "Cannot push a new version : this pipeline has a schedule and the new version is not schedulable (all parameters must be optional or have default values)."
+            "Cannot push a new version: this pipeline has a schedule and the new version cannot be scheduled "
+ "(all parameters must be optional or have default values)."
)
- elif PipelineErrorEnum.INVALID_TIMEOUT_VALUE.value in data["uploadPipeline"]["errors"]:
+ elif PipelineDefinitionErrorCode.INVALID_TIMEOUT_VALUE.value in data["uploadPipeline"]["errors"]:
raise InvalidDefinitionError(
"Timeout value is invalid : ensure that it's no negative and inferior to 12 hours."
)
diff --git a/openhexa/cli/cli.py b/openhexa/cli/cli.py
index 67b5131..9464eb3 100644
--- a/openhexa/cli/cli.py
+++ b/openhexa/cli/cli.py
@@ -1,3 +1,5 @@
+"""CLI module, with click."""
+
import base64
import json
import sys
@@ -13,14 +15,13 @@
delete_pipeline,
ensure_is_pipeline_dir,
get_pipeline,
- get_pipelines,
get_workspace,
is_debug,
+ list_pipelines,
open_config,
save_config,
upload_pipeline,
)
-from openhexa.cli.utils import terminate
from openhexa.sdk.pipelines import get_local_workspace_config, import_pipeline
@@ -29,9 +30,7 @@
@click.version_option(version("openhexa.sdk"))
@click.pass_context
def app(ctx, debug):
- """
- OpenHexa CLI
- """
+ """OpenHEXA CLI."""
# ensure that ctx.obj exists and is a dict (in case `cli()` is called
# by means other than the `if` block below)
ctx.ensure_object(dict)
@@ -44,9 +43,7 @@ def app(ctx, debug):
@app.group(invoke_without_command=True)
@click.pass_context
def workspaces(ctx):
- """
- Manage workspaces (add workspace, remove workspace, list workspaces, activate a workspace)
- """
+ """Manage workspaces (add workspace, remove workspace, list workspaces, activate a workspace)."""
if ctx.invoked_subcommand is None:
ctx.forward(workspaces_list)
@@ -55,9 +52,7 @@ def workspaces(ctx):
@click.argument("slug")
@click.option("--token", prompt=True, hide_input=True, confirmation_prompt=False)
def workspaces_add(slug, token):
- """
- Add a workspace to the configuration and activate it. The access token is required to access the workspace.
- """
+ """Add a workspace to the configuration and activate it. The access token is required to access the workspace."""
user_config = open_config()
if slug in user_config["workspaces"]:
click.echo(f"Workspace {slug} already exists. We will only update its token.")
@@ -83,10 +78,7 @@ def workspaces_add(slug, token):
@workspaces.command(name="activate")
@click.argument("slug")
def workspaces_activate(slug):
- """
- Activate a workspace that is already in the configuration. The activated workspace will be used for the 'pipelines' commands.
- """
-
+ """Activate a workspace that is already in the configuration. The activated workspace will be used for the 'pipelines' commands."""
user_config = open_config()
if slug not in user_config["workspaces"]:
click.echo(f"Workspace {slug} does not exist on {user_config['openhexa']['url']}. Available workspaces:")
@@ -100,9 +92,7 @@ def workspaces_activate(slug):
@workspaces.command(name="list")
def workspaces_list():
- """
- List the workspaces in the configuration.
- """
+ """List the workspaces in the configuration."""
user_config = open_config()
click.echo("Workspaces:")
@@ -119,8 +109,7 @@ def workspaces_list():
@workspaces.command(name="rm")
@click.argument("slug")
def workspaces_rm(slug):
- """
- Remove a workspace from the configuration.
+ """Remove a workspace from the configuration.
SLUG is the slug of the workspace to remove from the configuration.
"""
@@ -143,10 +132,7 @@ def workspaces_rm(slug):
@app.group(invoke_without_command=True)
@click.pass_context
def config(ctx):
- """
- Manage configuration of the CLI.
- """
-
+ """Manage configuration of the CLI."""
if ctx.invoked_subcommand is None:
user_config = open_config()
click.echo("Debug: " + ("True" if is_debug(user_config) else "False"))
@@ -162,10 +148,7 @@ def config(ctx):
@config.command(name="set_url")
@click.argument("url")
def config_set_url(url):
- """
- Set the URL of the backend.
-
- """
+ """Set the URL of the backend."""
user_config = open_config()
user_config["openhexa"].update({"url": url})
save_config(user_config)
@@ -175,9 +158,7 @@ def config_set_url(url):
@app.group(invoke_without_command=True)
@click.pass_context
def pipelines(ctx):
- """
- Manage pipelines (list pipelines, push a pipeline to the backend)
- """
+ """Manage pipelines (list pipelines, push a pipeline to the backend)."""
if ctx.invoked_subcommand is None:
ctx.forward(pipelines_list)
@@ -185,10 +166,7 @@ def pipelines(ctx):
@pipelines.command("init")
@click.argument("name", type=str)
def pipelines_init(name: str):
- """
- Initialize a new pipeline in a fresh directory.
- """
-
+ """Initialize a new pipeline in a fresh directory."""
new_pipeline_directory_name = stringcase.snakecase(name.lower())
new_pipeline_path = Path.cwd() / Path(new_pipeline_directory_name)
if new_pipeline_path.exists():
@@ -201,16 +179,16 @@ def pipelines_init(name: str):
# Load samples
sample_directory_path = Path(__file__).parent / Path("skeleton")
- with open(sample_directory_path / Path(".gitignore"), "r") as sample_ignore_file:
+ with open(sample_directory_path / Path(".gitignore")) as sample_ignore_file:
sample_ignore_content = sample_ignore_file.read()
- with open(sample_directory_path / Path("pipeline.py"), "r") as sample_pipeline_file:
+ with open(sample_directory_path / Path("pipeline.py")) as sample_pipeline_file:
sample_pipeline_content = (
sample_pipeline_file.read()
.replace("skeleton-pipeline-code", stringcase.spinalcase(name.lower()))
.replace("skeleton_pipeline_name", stringcase.snakecase(name.lower()))
.replace("Skeleton pipeline name", name)
)
- with open(sample_directory_path / Path("workspace.yaml"), "r") as sample_workspace_file:
+ with open(sample_directory_path / Path("workspace.yaml")) as sample_workspace_file:
sample_workspace_content = sample_workspace_file.read()
# Create directory
@@ -237,12 +215,10 @@ def pipelines_init(name: str):
@pipelines.command("push")
@click.argument("path", default=".", type=click.Path(exists=True, file_okay=False, dir_okay=True))
def pipelines_push(path: str):
- """
- Push a pipeline to the backend. If the pipeline already exists, it will be updated otherwise it will be created.
+    """Push a pipeline to the backend. If the pipeline already exists, it will be updated; otherwise it will be created.
PATH is the path to the pipeline file.
"""
-
user_config = open_config()
try:
workspace = user_config["openhexa"]["current_workspace"]
@@ -264,7 +240,7 @@ def pipelines_push(path: str):
raise e
sys.exit(1)
else:
- workspace_pipelines = get_pipelines(user_config)
+ workspace_pipelines = list_pipelines(user_config)
if is_debug(user_config):
click.echo(workspace_pipelines)
@@ -291,17 +267,17 @@ def pipelines_push(path: str):
"api", "app"
)
click.echo(
- f"Done! You can view the pipeline in OpenHexa on {click.style(url, fg='bright_blue', underline=True)}"
+ f"Done! You can view the pipeline in OpenHEXA on {click.style(url, fg='bright_blue', underline=True)}"
)
except InvalidDefinitionError as e:
- terminate(
+ _terminate(
f'Pipeline definition is invalid: "{e}"',
err=True,
exception=e,
debug=is_debug(user_config),
)
except Exception as e:
- terminate(
+ _terminate(
f'Error while importing pipeline: "{e}"',
err=True,
exception=e,
@@ -312,10 +288,7 @@ def pipelines_push(path: str):
@pipelines.command("delete")
@click.argument("code", type=str)
def pipelines_delete(code: str):
- """
- Delete a pipeline and all his versions.
- """
-
+    """Delete a pipeline and all its versions."""
user_config = open_config()
try:
workspace = user_config["openhexa"]["current_workspace"]
@@ -351,7 +324,7 @@ def pipelines_delete(code: str):
click.echo(f"Pipeline {click.style(code, bold=True)} deleted.")
except Exception as e:
- terminate(
+ _terminate(
f'Error while deleting pipeline: "{e}"',
err=True,
exception=e,
@@ -376,9 +349,7 @@ def pipelines_run(
config_str: str = "{}",
config_file: click.File = None,
):
- """
- Run a pipeline locally.
- """
+ """Run a pipeline locally."""
from subprocess import Popen
user_config = open_config()
@@ -440,9 +411,7 @@ def pipelines_run(
@pipelines.command("list")
def pipelines_list():
- """
- List all the remote pipelines of the current workspace.
- """
+ """List all the remote pipelines of the current workspace."""
user_config = open_config()
workspace = user_config["openhexa"]["current_workspace"]
@@ -450,15 +419,22 @@ def pipelines_list():
click.echo("No workspace activated", err=True)
sys.exit(1)
- workspace_pipelines = get_pipelines(user_config)
+ workspace_pipelines = list_pipelines(user_config)
if len(workspace_pipelines) == 0:
click.echo(f"No pipelines in workspace {workspace}")
return
click.echo("Pipelines:")
for pipeline in workspace_pipelines:
- version = pipeline["currentVersion"].get("number")
- if version:
- version = f"v{version}"
+ current_version = pipeline["currentVersion"].get("number")
+ if current_version is not None:
+ current_version = f"v{current_version}"
else:
- version = "N/A"
- click.echo(f"* {pipeline['code']} - {pipeline['name']} ({version})")
+ current_version = "N/A"
+ click.echo(f"* {pipeline['code']} - {pipeline['name']} ({current_version})")
+
+
+def _terminate(message: str, exception: Exception = None, err: bool = False, debug: bool = False):
+ click.echo(message, err=err)
+ if debug and exception:
+ raise exception
+ sys.exit(1)
diff --git a/openhexa/cli/skeleton/pipeline.py b/openhexa/cli/skeleton/pipeline.py
index 299de50..7b7156a 100644
--- a/openhexa/cli/skeleton/pipeline.py
+++ b/openhexa/cli/skeleton/pipeline.py
@@ -1,14 +1,21 @@
+"""Template for newly generated pipelines."""
+
from openhexa.sdk import current_run, pipeline
@pipeline("skeleton-pipeline-code", name="Skeleton pipeline name")
def skeleton_pipeline_name():
+ """Write your pipeline orchestration here.
+
+ Pipeline functions should only call tasks and should never perform IO operations or expensive computations.
+ """
count = task_1()
task_2(count)
@skeleton_pipeline_name.task
def task_1():
+ """Put some data processing code here."""
current_run.log_info("In task 1...")
return 42
@@ -16,6 +23,7 @@ def task_1():
@skeleton_pipeline_name.task
def task_2(count):
+ """Put some data processing code here."""
current_run.log_info(f"In task 2... count is {count}")
diff --git a/openhexa/cli/utils.py b/openhexa/cli/utils.py
deleted file mode 100644
index e29c890..0000000
--- a/openhexa/cli/utils.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import sys
-
-import click
-
-
-def terminate(message: str, exception: Exception = None, err: bool = False, debug: bool = False):
- click.echo(message, err=err)
- if debug and exception:
- raise exception
- sys.exit(1)
diff --git a/openhexa/sdk/__init__.py b/openhexa/sdk/__init__.py
index d7d1a71..d67e936 100644
--- a/openhexa/sdk/__init__.py
+++ b/openhexa/sdk/__init__.py
@@ -1,6 +1,8 @@
+"""SDK package."""
+
from .pipelines import current_run, parameter, pipeline
from .workspaces import workspace
-from .workspaces.connection import DHIS2Connection, IASOConnection, PostgreSQLConnection, GCSConnection, S3Connection
+from .workspaces.connection import DHIS2Connection, GCSConnection, IASOConnection, PostgreSQLConnection, S3Connection
__all__ = [
"workspace",
diff --git a/openhexa/sdk/datasets/__init__.py b/openhexa/sdk/datasets/__init__.py
index 7dab1bd..20efeca 100644
--- a/openhexa/sdk/datasets/__init__.py
+++ b/openhexa/sdk/datasets/__init__.py
@@ -1,3 +1,10 @@
+"""Datasets package.
+
+See https://github.com/BLSQ/openhexa/wiki/User-manual#datasets and
+https://github.com/BLSQ/openhexa/wiki/Using-the-OpenHEXA-SDK#working-with-datasets for more information about OpenHEXA
+datasets.
+"""
+
from .dataset import Dataset, DatasetFile
__all__ = ["Dataset", "DatasetFile"]
diff --git a/openhexa/sdk/datasets/dataset.py b/openhexa/sdk/datasets/dataset.py
index 2ce511b..f22b87c 100644
--- a/openhexa/sdk/datasets/dataset.py
+++ b/openhexa/sdk/datasets/dataset.py
@@ -1,3 +1,9 @@
+"""Dataset-related classes and functions.
+
+See https://github.com/BLSQ/openhexa/wiki/User-manual#datasets and
+https://github.com/BLSQ/openhexa/wiki/Using-the-OpenHEXA-SDK#working-with-datasets for more information about datasets.
+"""
+
import mimetypes
import typing
from os import PathLike
@@ -9,6 +15,8 @@
class DatasetFile:
+    """Represent a single file within a dataset. Files are attached to datasets through versions."""
+
_download_url = None
version = None
@@ -21,12 +29,14 @@ def __init__(self, version: any, id: str, uri: str, filename: str, content_type:
self.created_at = created_at
def read(self):
+ """Download the file content and return it."""
response = requests.get(self.download_url, stream=True)
response.raise_for_status()
return response.content
@property
def download_url(self):
+ """Build and return a pre-signed URL for the file."""
if self._download_url is None:
response = graphql(
"""
@@ -46,10 +56,13 @@ def download_url(self):
return self._download_url
def __repr__(self) -> str:
+ """Safe representation of the dataset file."""
return f""
class VersionsIterator(Iterator):
+ """Custom iterator class to iterate versions using our GraphQL API."""
+
def __init__(self, dataset: any, per_page: int = 10):
super().__init__(per_page=per_page)
@@ -90,6 +103,8 @@ def _next_page(self):
class VersionFilesIterator(Iterator):
+ """Custom iterator class to iterate version files using our GraphQL API."""
+
def __init__(self, version: any, per_page: int = 20):
super().__init__(per_page=per_page)
self.item_to_value = lambda x: DatasetFile(
@@ -137,6 +152,8 @@ def _next_page(self):
class DatasetVersion:
+ """Dataset files are not directly attached to a dataset, but rather to a version."""
+
dataset = None
def __init__(self, dataset: any, id: str, name: str, created_at: str):
@@ -147,11 +164,13 @@ def __init__(self, dataset: any, id: str, name: str, created_at: str):
@property
def files(self):
+ """Build and return an iterator of files for this version."""
if self.id is None:
raise ValueError("This dataset version does not have an id.")
return VersionFilesIterator(version=self, per_page=50)
- def get_file(self, filename: str):
+ def get_file(self, filename: str) -> DatasetFile:
+ """Get a file by name."""
data = graphql(
"""
query getDatasetFile($versionId: ID!, $filename: String!) {
@@ -171,7 +190,7 @@ def get_file(self, filename: str):
file = data["datasetVersion"]["fileByName"]
if file is None:
- return None
+        raise FileNotFound(f"The file {filename} does not exist for version {self}")
return DatasetFile(
version=self,
@@ -186,7 +205,8 @@ def add_file(
self,
source: typing.Union[str, PathLike[str], typing.IO],
filename: typing.Optional[str] = None,
- ):
+ ) -> DatasetFile:
+ """Create a new dataset file and add it to the dataset version."""
mime_type = None
if isinstance(source, (str, PathLike)):
path = Path(source)
@@ -246,6 +266,11 @@ def add_file(
class Dataset:
+ """Datasets are versioned, documented files.
+
+ See https://github.com/BLSQ/openhexa/wiki/Using-the-OpenHEXA-SDK#working-with-datasets for more information.
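+
+    Examples
+    --------
+    A short sketch, assuming ``dataset`` is an existing Dataset instance (version and file names are illustrative):
+
+    >>> version = dataset.create_version("v1")
+    >>> version.add_file(f"{workspace.files_path}/result.csv", filename="result.csv")
+    >>> for file in version.files:
+    ...     print(file.filename)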
+ """
+
_latest_version = None
def __init__(
@@ -260,9 +285,8 @@ def __init__(
self.name = name
self.description = description
- def create_version(self, name: typing.Any):
- # Check that all files exist
-
+ def create_version(self, name: typing.Any) -> DatasetVersion:
+ """Build a dataset version, save it and return it."""
response = graphql(
"""
mutation createDatasetVersion($input: CreateDatasetVersionInput!) {
@@ -299,7 +323,11 @@ def create_version(self, name: typing.Any):
return self.latest_version
@property
- def latest_version(self):
+ def latest_version(self) -> typing.Optional[DatasetVersion]:
+ """Return the latest version, if any.
+
+ This property method will query the backend to try to fetch the latest version.
+ """
if self._latest_version is None:
data = graphql(
"""
@@ -329,8 +357,14 @@ def latest_version(self):
return self._latest_version
@property
- def versions(self):
+ def versions(self) -> VersionsIterator:
+ """Build and return an iterator for versions."""
return VersionsIterator(dataset=self, per_page=10)
def __repr__(self) -> str:
+ """Safe representation of the dataset."""
return f""
+
+
+class FileNotFound(Exception):
+ """Raised whenever an attempt is made to get a file that does not exist."""
diff --git a/openhexa/sdk/pipelines/__init__.py b/openhexa/sdk/pipelines/__init__.py
index c2b3411..aa51bcd 100644
--- a/openhexa/sdk/pipelines/__init__.py
+++ b/openhexa/sdk/pipelines/__init__.py
@@ -1,11 +1,18 @@
+"""Pipelines package.
+
+See https://github.com/BLSQ/openhexa/wiki/User-manual#using-pipelines and
+https://github.com/BLSQ/openhexa/wiki/Writing-OpenHEXA-pipelines for more information about OpenHEXA pipelines.
+"""
+
from .parameter import parameter
-from .pipeline import pipeline
+from .pipeline import Pipeline, pipeline
from .run import current_run
from .runtime import download_pipeline, import_pipeline
from .utils import get_local_workspace_config
__all__ = [
"pipeline",
+ "Pipeline",
"parameter",
"current_run",
"import_pipeline",
diff --git a/openhexa/sdk/pipelines/parameter.py b/openhexa/sdk/pipelines/parameter.py
index 53bddab..04eafc6 100644
--- a/openhexa/sdk/pipelines/parameter.py
+++ b/openhexa/sdk/pipelines/parameter.py
@@ -1,40 +1,43 @@
+"""Pipeline parameters classes and functions.
+
+See https://github.com/BLSQ/openhexa/wiki/Writing-OpenHEXA-pipelines#pipeline-parameters for more information.
+"""
import re
import typing
+
+from openhexa.sdk.workspaces import workspace
from openhexa.sdk.workspaces.connection import (
+ Connection,
+ CustomConnection,
DHIS2Connection,
+ GCSConnection,
IASOConnection,
PostgreSQLConnection,
S3Connection,
- GCSConnection,
)
-from openhexa.sdk.workspaces import workspace
from openhexa.sdk.workspaces.workspace import ConnectionDoesNotExist
-class ParameterValueError(Exception):
- pass
-
-
class ParameterType:
- """Base class for parameter types. Those parameter types are used when using the @parameter decorator"""
+    """Base class for parameter types. These parameter types are used with the @parameter decorator."""
def spec_type(self) -> str:
- """Returns a type string for the specs that are sent to the backend."""
-
+ """Return a type string for the specs that are sent to the backend."""
raise NotImplementedError
@property
- def expected_type(self) -> typing.Type:
+ def expected_type(self) -> type:
"""Returns the python type expected for values."""
-
raise NotImplementedError
@property
- def accepts_choice(self) -> bool:
+ def accepts_choices(self) -> bool:
+ """Return True only if the parameter type supports the "choices" optional argument."""
return True
@property
def accepts_multiple(self) -> bool:
+ """Return True only if the parameter type supports multiple values."""
return True
@staticmethod
@@ -44,12 +47,10 @@ def normalize(value: typing.Any) -> typing.Any:
        This can be used to handle empty values and normalize them to None, or to perform type conversions, allowing us
        to accept multiple input types while still normalizing everything to a single type.
"""
-
return value
def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[typing.Any]:
"""Validate the provided value for this type."""
-
if not isinstance(value, self.expected_type):
raise ParameterValueError(
f"Invalid type for value {value} (expected {self.expected_type}, got {type(value)})"
@@ -58,23 +59,30 @@ def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[typing
return value
def validate_default(self, value: typing.Optional[typing.Any]):
+ """Validate the default value configured for this type."""
self.validate(value)
def __str__(self) -> str:
+        """Return the string representation of the parameter type."""
return str(self.expected_type)
class StringType(ParameterType):
+ """Type class for string parameters."""
+
@property
def spec_type(self) -> str:
+ """Return a type string for the specs that are sent to the backend."""
return "str"
@property
- def expected_type(self) -> typing.Type:
+ def expected_type(self) -> type:
+        """Return the Python type expected for values."""
return str
@staticmethod
def normalize(value: typing.Any) -> typing.Optional[str]:
+        """Strip leading and trailing whitespace and convert empty strings to None."""
if isinstance(value, str):
normalized_value = value.strip()
else:
@@ -86,6 +94,7 @@ def normalize(value: typing.Any) -> typing.Optional[str]:
return normalized_value
def validate_default(self, value: typing.Optional[typing.Any]):
+ """Validate the default value configured for this type."""
if value == "":
raise ParameterValueError("Empty values are not accepted.")
@@ -93,44 +102,59 @@ def validate_default(self, value: typing.Optional[typing.Any]):
class Boolean(ParameterType):
+ """Type class for boolean parameters."""
+
@property
def spec_type(self) -> str:
+ """Return a type string for the specs that are sent to the backend."""
return "bool"
@property
- def expected_type(self) -> typing.Type:
+ def expected_type(self) -> type:
+        """Return the Python type expected for values."""
return bool
@property
- def accepts_choice(self) -> bool:
+ def accepts_choices(self) -> bool:
+        """Return True only if the parameter type supports the "choices" optional argument."""
return False
@property
def accepts_multiple(self) -> bool:
+        """Return True only if the parameter type supports multiple values."""
return False
class Integer(ParameterType):
+ """Type class for integer parameters."""
+
@property
def spec_type(self) -> str:
+ """Return a type string for the specs that are sent to the backend."""
return "int"
@property
- def expected_type(self) -> typing.Type:
+ def expected_type(self) -> type:
+        """Return the Python type expected for values."""
return int
class Float(ParameterType):
+ """Type class for float parameters."""
+
@property
def spec_type(self) -> str:
+ """Return a type string for the specs that are sent to the backend."""
return "float"
@property
- def expected_type(self) -> typing.Type:
+ def expected_type(self) -> type:
+        """Return the Python type expected for values."""
return float
@staticmethod
def normalize(value: typing.Any) -> typing.Any:
+ """Normalize int values to float values if appropriate."""
if isinstance(value, int):
return float(value)
@@ -138,15 +162,20 @@ def normalize(value: typing.Any) -> typing.Any:
class ConnectionParameterType(ParameterType):
+ """Abstract base class for connection parameter type classes."""
+
@property
- def accepts_choice(self) -> bool:
+ def accepts_choices(self) -> bool:
+        """Return True only if the parameter type supports the "choices" optional argument."""
return False
@property
def accepts_multiple(self) -> bool:
+ """Return True only if the parameter type supports multiple values."""
return False
def validate_default(self, value: typing.Optional[typing.Any]):
+ """Validate the default value configured for this type."""
if value is None:
return
@@ -155,7 +184,8 @@ def validate_default(self, value: typing.Optional[typing.Any]):
elif value == "":
raise ParameterValueError("Empty values are not accepted.")
- def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[str]:
+ def validate(self, value: typing.Optional[typing.Any]) -> Connection:
+ """Validate the provided value for this type."""
if not isinstance(value, str):
raise ParameterValueError(f"Invalid type for value {value} (expected {str}, got {type(value)})")
@@ -164,86 +194,117 @@ def validate(self, value: typing.Optional[typing.Any]) -> typing.Optional[str]:
except ConnectionDoesNotExist as e:
raise ParameterValueError(str(e))
- def to_connection(self, value: str) -> typing.Any:
+ def to_connection(self, value: str) -> Connection:
+ """Build a connection instance from the provided value (which should be a connection identifier)."""
raise NotImplementedError
class PostgreSQLConnectionType(ConnectionParameterType):
+ """Type class for PostgreSQL connections."""
+
@property
def spec_type(self) -> str:
+ """Return a type string for the specs that are sent to the backend."""
return "postgresql"
@property
- def expected_type(self) -> typing.Type:
+ def expected_type(self) -> type:
+        """Return the Python type expected for values."""
return PostgreSQLConnectionType
- def to_connection(self, value: str) -> typing.Any:
+ def to_connection(self, value: str) -> PostgreSQLConnection:
+ """Build a PostgreSQL connection instance from the provided value (which should be a connection identifier)."""
return workspace.postgresql_connection(value)
class S3ConnectionType(ConnectionParameterType):
+ """Type class for S3 connections."""
+
@property
def spec_type(self) -> str:
+ """Return a type string for the specs that are sent to the backend."""
return "s3"
@property
- def expected_type(self) -> typing.Type:
+ def expected_type(self) -> type:
+        """Return the Python type expected for values."""
return S3ConnectionType
- def to_connection(self, value: str) -> typing.Any:
+ def to_connection(self, value: str) -> S3Connection:
+        """Build an S3 connection instance from the provided value (which should be a connection identifier)."""
return workspace.s3_connection(value)
class GCSConnectionType(ConnectionParameterType):
+ """Type class for GCS connections."""
+
@property
def spec_type(self) -> str:
+ """Return a type string for the specs that are sent to the backend."""
return "gcs"
@property
- def expected_type(self) -> typing.Type:
+ def expected_type(self) -> type:
+        """Return the Python type expected for values."""
return GCSConnectionType
- def to_connection(self, value: str) -> typing.Any:
+ def to_connection(self, value: str) -> GCSConnection:
+ """Build a GCS connection instance from the provided value (which should be a connection identifier)."""
return workspace.gcs_connection(value)
class DHIS2ConnectionType(ConnectionParameterType):
+ """Type class for DHIS2 connections."""
+
@property
def spec_type(self) -> str:
+ """Return a type string for the specs that are sent to the backend."""
return "dhis2"
@property
- def expected_type(self) -> typing.Type:
+ def expected_type(self) -> type:
+        """Return the Python type expected for values."""
return DHIS2ConnectionType
- def to_connection(self, value: str) -> typing.Any:
+ def to_connection(self, value: str) -> DHIS2Connection:
+ """Build a DHIS2 connection instance from the provided value (which should be a connection identifier)."""
return workspace.dhis2_connection(value)
class IASOConnectionType(ConnectionParameterType):
+ """Type class for IASO connections."""
+
@property
def spec_type(self) -> str:
+ """Return a type string for the specs that are sent to the backend."""
return "iaso"
@property
- def expected_type(self) -> typing.Type:
+ def expected_type(self) -> type:
+        """Return the Python type expected for values."""
return IASOConnectionType
- def to_connection(self, value: str) -> typing.Any:
+ def to_connection(self, value: str) -> IASOConnection:
+        """Build an IASO connection instance from the provided value (which should be a connection identifier)."""
return workspace.iaso_connection(value)
class CustomConnectionType(ConnectionParameterType):
+ """Type class for custom connections."""
+
@property
def spec_type(self) -> str:
+ """Return a type string for the specs that are sent to the backend."""
return "custom"
@property
- def expected_type(self) -> typing.Type:
+ def expected_type(self) -> type:
+        """Return the Python type expected for values."""
return str
- def to_connection(self, value: str) -> typing.Any:
- return workspace.postgresql_connection(value)
+ def to_connection(self, value: str) -> CustomConnection:
+ """Build a custom connection instance from the provided value (which should be a connection identifier)."""
+ return workspace.custom_connection(value)
TYPES_BY_PYTHON_TYPE = {
@@ -259,10 +320,6 @@ def to_connection(self, value: str) -> typing.Any:
}
-class InvalidParameterError(Exception):
- pass
-
-
class Parameter:
"""Pipeline parameter class. Contains validation logic specs generation logic."""
@@ -270,7 +327,7 @@ def __init__(
self,
code: str,
*,
- type: typing.Union[typing.Type[str], typing.Type[int], typing.Type[bool]],
+ type: typing.Union[type[str], type[int], type[bool]],
name: typing.Optional[str] = None,
choices: typing.Optional[typing.Sequence] = None,
help: typing.Optional[str] = None,
@@ -280,7 +337,8 @@ def __init__(
):
if re.match("^[a-z_][a-z_0-9]+$", code) is None:
raise InvalidParameterError(
- f"Invalid parameter code provided ({code}). Parameter must start with a letter or an underscore, and can only contain lower case letters, numbers and underscores."
+ f"Invalid parameter code provided ({code}). Parameter must start with a letter or an underscore, "
+ f"and can only contain lower case letters, numbers and underscores."
)
self.code = code
@@ -290,11 +348,12 @@ def __init__(
except KeyError:
valid_parameter_types = [str(k) for k in TYPES_BY_PYTHON_TYPE.keys()]
raise InvalidParameterError(
- f"Invalid parameter type provided ({type}). Valid parameter types are {', '.join(valid_parameter_types)}"
+ f"Invalid parameter type provided ({type}). "
+ f"Valid parameter types are {', '.join(valid_parameter_types)}"
)
if choices is not None:
- if not self.type.accepts_choice:
+ if not self.type.accepts_choices:
raise InvalidParameterError(f"Parameters of type {self.type} don't accept choices.")
elif len(choices) == 0:
raise InvalidParameterError("Choices, if provided, cannot be empty.")
@@ -318,8 +377,7 @@ def __init__(
self.default = default
def validate(self, value: typing.Any) -> typing.Any:
- """Validates the provided value against the parameter, taking required / default options into account."""
-
+ """Validate the provided value against the parameter, taking required / default options into account."""
if self.multiple:
return self._validate_multiple(value)
else:
@@ -383,9 +441,8 @@ def _validate_default(self, default: typing.Any, multiple: bool):
except ParameterValueError:
raise InvalidParameterError(f"The default value for {self.code} is not valid.")
- def parameter_spec(self):
- """Generates specification for this parameter, to be provided to the OpenHexa backend."""
-
+ def parameter_spec(self) -> dict[str, typing.Any]:
+ """Build specification for this parameter, to be provided to the OpenHEXA backend."""
return {
"type": self.type.spec_type,
"required": self.required,
@@ -402,15 +459,15 @@ def parameter(
code: str,
*,
type: typing.Union[
- typing.Type[str],
- typing.Type[int],
- typing.Type[bool],
- typing.Type[float],
- typing.Type[DHIS2Connection],
- typing.Type[IASOConnection],
- typing.Type[PostgreSQLConnection],
- typing.Type[GCSConnection],
- typing.Type[S3Connection],
+ type[str],
+ type[int],
+ type[bool],
+ type[float],
+ type[DHIS2Connection],
+ type[IASOConnection],
+ type[PostgreSQLConnection],
+ type[GCSConnection],
+ type[S3Connection],
],
name: typing.Optional[str] = None,
choices: typing.Optional[typing.Sequence] = None,
@@ -419,7 +476,7 @@ def parameter(
required: bool = True,
multiple: bool = False,
):
- """Decorator that attaches a parameter to an OpenHexa pipeline.
+    """Decorate a pipeline function by attaching a parameter to it.
This decorator must be used on a function decorated by the @pipeline decorator.
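+
+    Examples
+    --------
+    A minimal sketch combining @pipeline and @parameter (the parameter code and name are illustrative):
+
+    >>> @pipeline("my-pipeline")
+    ... @parameter("amount", name="Amount", type=int, default=10)
+    ... def my_pipeline(amount: int):
+    ...     a_task(amount)
+    ...
+    ... @my_pipeline.task
+    ... def a_task(amount: int):
+    ...     current_run.log_info(f"Amount is {amount}")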
@@ -470,17 +527,31 @@ def decorator(fun):
class FunctionWithParameter:
- """This class serves as a wrapper for functions decorated with the @parameter decorator."""
+ """Wrapper class for pipeline functions decorated with the @parameter decorator."""
def __init__(self, function, added_parameter: Parameter):
self.function = function
self.parameter = added_parameter
- def __call__(self, *args, **kwargs):
- return self.function(*args, **kwargs)
-
- def get_all_parameters(self):
+ def get_all_parameters(self) -> list[Parameter]:
+        """Go through the decorator chain to find all pipeline parameters."""
if isinstance(self.function, FunctionWithParameter):
return [self.parameter, *self.function.get_all_parameters()]
return [self.parameter]
+
+ def __call__(self, *args, **kwargs):
+ """Call the decorated pipeline function."""
+ return self.function(*args, **kwargs)
+
+
+class InvalidParameterError(Exception):
+ """Raised whenever parameter options (usually passed to the @parameter decorator) are invalid."""
+
+ pass
+
+
+class ParameterValueError(Exception):
+    """Raised whenever a parameter value provided for a pipeline run is invalid."""
+
+ pass
diff --git a/openhexa/sdk/pipelines/pipeline.py b/openhexa/sdk/pipelines/pipeline.py
index 0d2b32a..260aadb 100644
--- a/openhexa/sdk/pipelines/pipeline.py
+++ b/openhexa/sdk/pipelines/pipeline.py
@@ -1,4 +1,9 @@
-from __future__ import annotations
+"""Main pipeline module containing the building blocks for OpenHEXA pipelines.
+
+See https://github.com/BLSQ/openhexa/wiki/User-manual#using-pipelines and
+https://github.com/BLSQ/openhexa/wiki/Writing-OpenHEXA-pipelines for more information about OpenHEXA pipelines.
+"""
+
import argparse
import datetime
@@ -14,64 +19,38 @@
import requests
from multiprocess import get_context # NOQA
-from openhexa.sdk.utils import Environments, get_environment
+from openhexa.sdk.utils import Environment, get_environment
from .parameter import (
FunctionWithParameter,
Parameter,
ParameterValueError,
)
-from .task import PipelineWithTask
+from .task import PipelineWithTask, Task
from .utils import get_local_workspace_config
logger = getLogger(__name__)
-class PipelineConfigError(Exception):
- pass
-
+class Pipeline:
+ """OpenHEXA pipeline class.
-def pipeline(
- code: str, *, name: str = None, timeout: int = None
-) -> typing.Callable[[typing.Callable[..., typing.Any]], "Pipeline"]:
- """Decorator that turns a Python function into an OpenHexa pipeline.
+    Pipelines are usually instantiated through the @pipeline decorator.
- Parameters
+ Attributes
----------
code : str
- An identifier for the pipeline (should be unique within the workspace where the pipeline is deployed)
- name : str, optional
- An optional name for the pipeline (will be used instead of the code in the web interface)
- timeout : int, optional
- An optional timeout, in seconds, after which the pipeline run will be terminated (if not provided, a default
- timeout will be applied by the OpenHexa backend)
-
- Returns
- -------
- typing.Callable
- A decorator that returns a Pipeline
-
+ A unique code to identify the pipeline within a workspace.
+ name : str
+ A user-friendly name for the pipeline (will be displayed in the web interface).
+ function: typing.Callable
+ The actual pipeline function.
+ parameters : typing.Sequence[Parameter]
+        A list of Parameter instances corresponding to the pipeline parameters.
+ timeout : int
+ The timeout in seconds after which the pipeline will be killed.
"""
- if any(c not in string.ascii_lowercase + string.digits + "_-" for c in code):
- raise Exception("Pipeline code should contains only lower case letters, digits, '_' and '-'")
-
- def decorator(fun):
- if isinstance(fun, FunctionWithParameter):
- parameters = fun.get_all_parameters()
- else:
- parameters = []
-
- return Pipeline(code, name, fun, parameters, timeout)
-
- return decorator
-
-
-class PipelineRunError(Exception):
- pass
-
-
-class Pipeline:
def __init__(
self,
code: str,
@@ -87,12 +66,34 @@ def __init__(
self.timeout = timeout
self.tasks = []
- def task(self, function):
- """task decorator"""
-
+ def task(self, function) -> PipelineWithTask:
+ """Task decorator.
+
+ Examples
+ --------
+ >>> @pipeline("my-pipeline")
+ ... def my_pipeline():
+        ...     result_1 = task_1()
+        ...     task_2(result_1)
+ ...
+ ... @my_pipeline.task
+ ... def task_1() -> int:
+ ... return 42
+ ...
+ ... @my_pipeline.task
+ ... def task_2(foo: int):
+ ... pass
+ """
return PipelineWithTask(function, self)
- def run(self, config: typing.Dict[str, typing.Any]):
+ def run(self, config: dict[str, typing.Any]):
+ """Run the pipeline using the provided config.
+
+ Parameters
+ ----------
+        config : dict[str, typing.Any]
+ The parameter values to use for this pipeline run.
+ """
now = datetime.datetime.utcnow().replace(microsecond=0).isoformat()
print(f'{now} Starting pipeline "{self.code}"')
@@ -116,7 +117,7 @@ def run(self, config: typing.Dict[str, typing.Any]):
completed = 0
while True:
- tasks = self.get_available_tasks()
+ tasks = self._get_available_tasks()
# filter already pooled task, even if they are not finished
tasks = [t for t in tasks if not t.pooled]
@@ -166,12 +167,13 @@ def run(self, config: typing.Dict[str, typing.Any]):
print(f'{now} Successfully completed pipeline "{self.code}"')
- def get_available_tasks(self):
- return [task for task in self.tasks if task.is_ready()]
-
- def parameters_spec(self):
+ def parameters_spec(self) -> list[dict[str, typing.Any]]:
+ """Return the individual specifications of all the parameters of this pipeline."""
return [arg.parameter_spec() for arg in self.parameters]
+ def _get_available_tasks(self) -> list[Task]:
+ return [task for task in self.tasks if task.is_ready()]
+
def _update_progress(self, progress: int):
if self._connected:
token = os.environ["HEXA_TOKEN"]
@@ -193,19 +195,25 @@ def _update_progress(self, progress: int):
print(f"Progress update: {progress}%")
@property
- def _connected(self):
+ def _connected(self) -> bool:
env = get_environment()
- return env == Environments.CLOUD_PIPELINE and "HEXA_SERVER_URL" in os.environ
- def __call__(self, config: typing.Optional[typing.Dict[str, typing.Any]] = None):
+ return env == Environment.CLOUD_PIPELINE and "HEXA_SERVER_URL" in os.environ
+
+ def __call__(self, config: typing.Optional[dict[str, typing.Any]] = None):
+ """Call the pipeline by running it, after having configured the environment.
+
+ This method can be called with an explicit configuration. If no configuration is provided, it will parse the
+ command-line arguments to build it.
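+
+        Examples
+        --------
+        A minimal sketch (the parameter code is illustrative):
+
+        >>> my_pipeline({"some_param": 42})
+
+        When the pipeline file is executed directly with `python pipeline.py`, the configuration can instead be
+        provided on the command line through the --config-file/-f argument.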
+ """
# Handle local workspace config for dev / testing, if appropriate
- if get_environment() == Environments.LOCAL_PIPELINE:
+ if get_environment() == Environment.LOCAL_PIPELINE:
os.environ.update(get_local_workspace_config(Path("/home/hexa/pipeline")))
# User can run their pipeline using `python pipeline.py`. It's considered as a standalone usage of the library.
# Since we still support this use case for the moment, we'll try to load the workspace.yaml
# at the path of the file
- elif get_environment() == Environments.STANDALONE:
+ elif get_environment() == Environment.STANDALONE:
os.environ.update(get_local_workspace_config(Path(sys.argv[0]).parent))
if config is None: # Called without arguments, in the pipeline file itself
@@ -221,7 +229,7 @@ def __call__(self, config: typing.Optional[typing.Dict[str, typing.Any]] = None)
"config file with the --config-file/-f argument."
)
if args.config_file is not None:
- with open(args.config_file, "r") as cf:
+ with open(args.config_file) as cf:
try:
config = json.load(cf)
except json.JSONDecodeError:
@@ -236,3 +244,59 @@ def __call__(self, config: typing.Optional[typing.Dict[str, typing.Any]] = None)
config = {}
self.run(config)
+
+
+def pipeline(
+ code: str, *, name: str = None, timeout: int = None
+) -> typing.Callable[[typing.Callable[..., typing.Any]], Pipeline]:
+ """Decorate a Python function as an OpenHEXA pipeline.
+
+ Parameters
+ ----------
+ code : str
+ An identifier for the pipeline (should be unique within the workspace where the pipeline is deployed)
+ name : str, optional
+ An optional name for the pipeline (will be used instead of the code in the web interface)
+ timeout : int, optional
+ An optional timeout, in seconds, after which the pipeline run will be terminated (if not provided, a default
+ timeout will be applied by the OpenHEXA backend)
+
+ Returns
+ -------
+ typing.Callable
+ A decorator that returns a Pipeline
+
+ Examples
+ --------
+ >>> @pipeline("my-pipeline")
+ ... def my_pipeline():
+ ... a_task()
+ ...
+ ... @my_pipeline.task
+ ... def a_task() -> int:
+ ... return 42
+ """
+ if any(c not in string.ascii_lowercase + string.digits + "_-" for c in code):
+        raise Exception("Pipeline code should contain only lower case letters, digits, '_' and '-'")
+
+ def decorator(fun):
+ if isinstance(fun, FunctionWithParameter):
+ parameters = fun.get_all_parameters()
+ else:
+ parameters = []
+
+ return Pipeline(code, name, fun, parameters, timeout)
+
+ return decorator
+
+
+class PipelineConfigError(Exception):
+    """Error raised whenever the config passed to the pipeline run method is invalid."""
+
+ pass
+
+
+class PipelineRunError(Exception):
+ """Generic pipeline runtime error, raised whenever user code raises an exception."""
+
+ pass
diff --git a/openhexa/sdk/pipelines/run.py b/openhexa/sdk/pipelines/run.py
index c904e14..c79bdaa 100644
--- a/openhexa/sdk/pipelines/run.py
+++ b/openhexa/sdk/pipelines/run.py
@@ -1,22 +1,29 @@
+"""Pipeline run module."""
+
import datetime
import os
import typing
-from pathlib import Path
-from openhexa.sdk.utils import Environments, get_environment, graphql
+from openhexa.sdk.utils import Environment, get_environment, graphql
from openhexa.sdk.workspaces import workspace
class CurrentRun:
+ """Represents the current run of a pipeline.
+
+ CurrentRun instances allow pipeline developers to interact with the OpenHEXA backend, by sending messages and
+ adding outputs that will be available through the web interface.
+ """
+
@property
def _connected(self):
return "HEXA_SERVER_URL" in os.environ
- @property
- def tmp_path(self):
- return Path("~/tmp/")
-
def add_file_output(self, path: str):
+ """Record a run output for a file creation operation.
+
+ This output will be visible in the web interface, on the pipeline run page.
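+
+ Examples
+ --------
+ >>> # Illustrative sketch: the file is assumed to live under the workspace filesystem
+ >>> current_run.add_file_output(f"{workspace.files_path}/output.csv")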
+ """
stripped_path = path.replace(workspace.files_path, "")
name = stripped_path.strip("/")
if self._connected:
@@ -37,6 +44,10 @@ def add_file_output(self, path: str):
print(f"Sending output with path {stripped_path}")
def add_database_output(self, table_name: str):
+ """Record a run output for a database operation.
+
+ This output will be visible in the web interface, on the pipeline run page.
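+
+ Examples
+ --------
+ >>> # Illustrative sketch: "my_table" stands for a table in the workspace database
+ >>> current_run.add_database_output("my_table")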
+ """
if self._connected:
graphql(
"""
@@ -55,18 +66,23 @@ def add_database_output(self, table_name: str):
print(f"Sending output with table_name {table_name}")
def log_debug(self, message: str):
+ """Log a message with the DEBUG priority."""
self._log_message("DEBUG", message)
def log_info(self, message: str):
+ """Log a message with the INFO priority."""
self._log_message("INFO", message)
def log_warning(self, message: str):
+ """Log a message with the WARNING priority."""
self._log_message("WARNING", message)
def log_error(self, message: str):
+ """Log a message with the ERROR priority."""
self._log_message("ERROR", message)
def log_critical(self, message: str):
+ """Log a message with the CRITICAL priority."""
self._log_message("CRITICAL", message)
def _log_message(
@@ -91,7 +107,7 @@ def _log_message(
print(now, priority, message)
-if get_environment() == Environments.CLOUD_JUPYTER:
+if get_environment() == Environment.CLOUD_JUPYTER:
current_run = None
else:
current_run = CurrentRun()
diff --git a/openhexa/sdk/pipelines/runtime.py b/openhexa/sdk/pipelines/runtime.py
index bb4fdfe..d769d54 100644
--- a/openhexa/sdk/pipelines/runtime.py
+++ b/openhexa/sdk/pipelines/runtime.py
@@ -1,3 +1,5 @@
+"""Utilities used by containerized pipeline runners to import and download pipelines."""
+
import base64
import importlib
import io
@@ -11,6 +13,7 @@
def import_pipeline(pipeline_dir_path: str):
+ """Import pipeline code within provided path using importlib."""
pipeline_dir = os.path.abspath(pipeline_dir_path)
sys.path.append(pipeline_dir)
pipeline_package = importlib.import_module("pipeline")
@@ -19,7 +22,8 @@ def import_pipeline(pipeline_dir_path: str):
return pipeline
-def download_pipeline(url: str, token: str, run_id: str, target_dir):
+def download_pipeline(url: str, token: str, run_id: str, target_dir: str):
+ """Download pipeline code and unzip it into the target directory."""
r = requests.post(
url + "/graphql/",
headers={"Authorization": f"Bearer {token}"},
diff --git a/openhexa/sdk/pipelines/task.py b/openhexa/sdk/pipelines/task.py
index 31b9b1d..f545abf 100644
--- a/openhexa/sdk/pipelines/task.py
+++ b/openhexa/sdk/pipelines/task.py
@@ -1,3 +1,8 @@
+"""Classes and functions related to pipeline tasks.
+
+See https://github.com/BLSQ/openhexa/wiki/Writing-OpenHEXA-pipelines#pipelines-and-tasks for more information.
+"""
+
from __future__ import annotations
import datetime
@@ -7,6 +12,11 @@
class TaskCom:
+ """Lightweight data transfer object allowing tasks to communicate.
+
+ TaskCom instances also allow us to build the pipeline dependency graph.
+ """
+
def __init__(self, task):
self.result = task.result
self.start_time = task.start_time
@@ -14,6 +24,11 @@ def __init__(self, task):
class Task:
+ """Tasks are pipeline data processing code units.
+
+ See https://github.com/BLSQ/openhexa/wiki/Writing-OpenHEXA-pipelines#pipelines-and-tasks for more information.
+ """
+
def __init__(self, function: typing.Callable):
self.name = function.__name__
self.compute = function
@@ -26,27 +41,11 @@ def __init__(self, function: typing.Callable):
self.active = False
self.pooled = False
- def __call__(self, *task_args, **task_kwargs):
- self.active = True # uncalled tasks will be skipped
- # check that all inputs are tasks
- self.task_args = task_args
- self.task_kwargs = task_kwargs
- return self
-
- def __repr__(self):
- return self.name
-
- def get_node_inputs(self):
- inputs = []
- for a in self.task_args:
- if issubclass(type(a), Task):
- inputs.append(a)
- for k, a in self.task_kwargs.items():
- if issubclass(type(a), Task):
- inputs.append(a)
- return inputs
+ def is_ready(self) -> bool:
+ """Determine whether the task is ready to be run.
- def is_ready(self):
+ This involves checking whether tasks higher up in the dependency graph have been executed.
+ """
if not self.active:
return False
@@ -59,24 +58,32 @@ def is_ready(self):
return True if self.end_time is None else False
- def get_tasks_ready(self):
+ def get_ready_tasks(self) -> list[Task]:
+ """Find and return all tasks that can be launched at this point in time."""
tasks = []
for a in self.task_args:
if issubclass(type(a), Task):
if a.is_ready():
tasks.append(a)
else:
- tasks += a.get_tasks_ready()
+ tasks += a.get_ready_tasks()
for k, a in self.task_kwargs.items():
if issubclass(type(a), Task):
if a.is_ready():
tasks.append(a)
else:
- tasks += a.get_tasks_ready()
+ tasks += a.get_ready_tasks()
return list(set(tasks))
- def run(self):
+ def run(self) -> TaskCom:
+ """Run the task.
+
+ Returns
+ -------
+ TaskCom
+ A TaskCom instance which can in turn be passed to other tasks.
+ """
if self.end_time:
# already executed, return previous result
return self.result
@@ -106,22 +113,33 @@ def run(self):
# done!
return TaskCom(self)
- def stateless_run(self):
- self.result = None
- self.start_time, self.end_time = None, None
- return self.run()
+ def __call__(self, *task_args, **task_kwargs):
+ """Wrap the task with args and kwargs and return it."""
+ self.active = True # uncalled tasks will be skipped
+ # check that all inputs are tasks
+ self.task_args = task_args
+ self.task_kwargs = task_kwargs
+
+ return self
+
+ def __repr__(self):
+ """Representation of the task using its name."""
+ return self.name
class PipelineWithTask:
+ """Pipeline with attached tasks, usually through the @task decorator."""
+
def __init__(
self,
function: typing.Callable,
- pipeline: openhexa.sdk.pipelines.pipeline.Pipeline,
+ pipeline: openhexa.sdk.pipelines.Pipeline,
):
self.function = function
self.pipeline = pipeline
- def __call__(self, *task_args, **task_kwargs):
+ def __call__(self, *task_args, **task_kwargs) -> Task:
+ """Attach the new task to the decorated pipeline and return it."""
task = Task(self.function)(*task_args, **task_kwargs)
self.pipeline.tasks.append(task)
return task
diff --git a/openhexa/sdk/pipelines/utils.py b/openhexa/sdk/pipelines/utils.py
index 6e70ee0..7fae0fc 100644
--- a/openhexa/sdk/pipelines/utils.py
+++ b/openhexa/sdk/pipelines/utils.py
@@ -1,3 +1,5 @@
+"""Utilities for running local pipelines."""
+
from pathlib import Path
from tempfile import mkdtemp
@@ -6,6 +8,8 @@
class LocalWorkspaceConfigError(Exception):
+ """Raised whenever the local workspace config file does not exist or is invalid."""
+
pass
@@ -18,7 +22,6 @@ def get_local_workspace_config(path: Path):
This is obviously brittle as it relies on setting the correct env variables keys, any changes upstream must
be reflected here.
"""
-
env_vars = {}
# This will only work when running the pipeline using "python pipeline.py"
@@ -29,7 +32,7 @@ def get_local_workspace_config(path: Path):
"To work with pipelines locally, you need a workspace.yaml file in the same directory as your pipeline file"
)
- with open(local_workspace_config_path.resolve(), "r") as local_workspace_config_file:
+ with open(local_workspace_config_path.resolve()) as local_workspace_config_file:
local_workspace_config = yaml.safe_load(local_workspace_config_file)
# Database config
if "database" in local_workspace_config:
diff --git a/openhexa/sdk/utils.py b/openhexa/sdk/utils.py
index fb983b2..4bd9cf1 100644
--- a/openhexa/sdk/utils.py
+++ b/openhexa/sdk/utils.py
@@ -1,3 +1,5 @@
+"""Miscellaneous utility functions."""
+
import abc
import enum
import os
@@ -6,7 +8,9 @@
import requests
-class Environments(enum.Enum):
+class Environment(enum.Enum):
+ """Enumeration of supported runtime environments."""
+
LOCAL_PIPELINE = "LOCAL_PIPELINE"
CLOUD_PIPELINE = "CLOUD_PIPELINE"
CLOUD_JUPYTER = "CLOUD_JUPYTER"
@@ -14,13 +18,17 @@ class Environments(enum.Enum):
def get_environment():
+ """Get the environment from the HEXA_ENVIRONMENT (see Environment enum)."""
env = os.environ.get("HEXA_ENVIRONMENT", "STANDALONE").upper()
- if env not in Environments.__members__:
+
+ try:
+ return Environment[env]
+ except KeyError:
raise ValueError(f"Invalid environment: {env}")
- return Environments[env]
-def graphql(operation: str, variables: typing.Optional[typing.Dict[str, typing.Any]] = None):
+def graphql(operation: str, variables: typing.Optional[dict[str, typing.Any]] = None) -> dict[str, typing.Any]:
+ """Performa GraphQL query."""
auth_token = os.environ[
"HEXA_TOKEN"
] # Works for notebooks with the membership token and pipelines with the run token
@@ -42,7 +50,7 @@ def graphql(operation: str, variables: typing.Optional[typing.Dict[str, typing.A
return body["data"]
-class Iterator(object, metaclass=abc.ABCMeta):
+class Iterator(metaclass=abc.ABCMeta):
"""A generic class for iterating through API list responses."""
def __init__(
@@ -68,42 +76,35 @@ def __init__(
"""int: The total number of results fetched so far."""
def _items_iter(self):
- """Iterator for each item returned."""
for page in self._page_iter(increment=False):
for item in page:
self.num_results += 1
yield item
- def __iter__(self):
- """Iterator for each item returned.
-
- Returns:
- types.GeneratorType[Any]: A generator of items from the API.
-
- Raises:
- ValueError: If the iterator has already been started.
- """
+ def __iter__(self) -> typing.Generator[typing.Any, None, None]:
+ """Implement __iter()."""
if self._started:
raise ValueError("Iterator has already started", self)
self._started = True
+
return self._items_iter()
def __next__(self):
+ """Implement next()."""
if self.__active_iterator is None:
self.__active_iterator = iter(self)
- return next(self.__active_iterator)
- def _page_iter(self, increment):
- """Generator of pages of API responses.
+ return next(self.__active_iterator)
- Args:
- increment (bool): Flag indicating if the total number of results
- should be incremented on each page. This is useful since a page
- iterator will want to increment by results per page while an
- items iterator will want to increment per item.
+ def _page_iter(self, increment: bool):
+ """Generate pages of API responses.
- Yields:
- Page: each page of items from the API.
+ Parameters
+ ----------
+ increment : bool
+ Flag indicating if the total number of results should be incremented on each page.
+ This is useful since a page iterator will want to increment by results per page while an
+ items iterator will want to increment per item.
"""
page = self._next_page()
while page is not None:
@@ -120,24 +121,25 @@ def _next_page(self):
This does nothing and is intended to be over-ridden by subclasses
to return the next :class:`Page`.
- Raises:
+ Raises
+ ------
NotImplementedError: Always, this method is abstract.
"""
raise NotImplementedError
-class Page(object):
+class Page:
"""Single page of results in an iterator.
- Args:
- parent (Iterator): The iterator that owns
- the current page.
- items (Sequence[Any]): An iterable (that also defines __len__) of items
- from a raw API response.
- item_to_value (Callable[google.api_core.page_iterator.Iterator, Any]):
- Callable to convert an item from the type in the raw API response
- into the native object. Will be called with the iterator and a
- single item.
+ Parameters
+ ----------
+ parent : Iterator
+ The iterator that owns the current page.
+ items : Sequence[Any]
+ An iterable (that also defines __len__) of items from a raw API response.
+ item_to_value : Callable[[Iterator, dict[str, Any]], Any]
+ Callable to convert an item from the type in the raw API response into the native object.
+ Will be called with the iterator and a single item.
"""
def __init__(self, parent, items, item_to_value):
@@ -158,7 +160,7 @@ def remaining(self):
return self._remaining
def __iter__(self):
- """The :class:`Page` is an iterator of items."""
+ """Implement __iter__()."""
return self
def __next__(self):
@@ -171,7 +173,8 @@ def __next__(self):
return result
-def read_content(source: typing.Union[str, os.PathLike[str], typing.IO], encoding: str = "utf-8") -> str:
+def read_content(source: typing.Union[str, os.PathLike[str], typing.IO], encoding: str = "utf-8") -> bytes:
+ """Read file content and return it as bytes."""
# If source is a string or PathLike object
if isinstance(source, (str, os.PathLike)):
with open(os.fspath(source), "rb") as f:
diff --git a/openhexa/sdk/workspaces/__init__.py b/openhexa/sdk/workspaces/__init__.py
index 8faffd8..94bf94c 100644
--- a/openhexa/sdk/workspaces/__init__.py
+++ b/openhexa/sdk/workspaces/__init__.py
@@ -1,3 +1,8 @@
+"""Workspaces package.
+
+See https://github.com/BLSQ/openhexa/wiki/User-manual#about-workspaces for more information about OpenHEXA workspaces.
+"""
+
from .workspace import workspace
__all__ = ["workspace"]
diff --git a/openhexa/sdk/workspaces/connection.py b/openhexa/sdk/workspaces/connection.py
index 51b79af..2f6c87a 100644
--- a/openhexa/sdk/workspaces/connection.py
+++ b/openhexa/sdk/workspaces/connection.py
@@ -1,18 +1,38 @@
+"""Connection test module."""
+
import dataclasses
@dataclasses.dataclass
-class DHIS2Connection:
+class Connection:
+ """Abstract base class for connections."""
+
+ pass
+
+
+@dataclasses.dataclass
+class DHIS2Connection(Connection):
+ """DHIS2 connection.
+
+ See https://docs.dhis2.org/ for more information.
+ """
+
url: str
username: str
password: str
def __repr__(self):
+ """Safe representation of the DHIS2 connection (no credentials)."""
return f"DHIS2Connection(url='{self.url}', username='{self.username}')"
@dataclasses.dataclass
-class PostgreSQLConnection:
+class PostgreSQLConnection(Connection):
+ """PostgreSQL database connection.
+
+ See https://www.postgresql.org/docs/ for more information.
+ """
+
host: str
port: int
username: str
@@ -20,37 +40,76 @@ class PostgreSQLConnection:
database_name: str
def __repr__(self):
- return f"PostgreSQLConnection(host='{self.host}', port='{self.port}', username='{self.username}', database_name='{self.database_name}')"
+ """Safe representation of the PostgreSQL connection (no credentials)."""
+ return (
+ f"PostgreSQLConnection(host='{self.host}', port='{self.port}', username='{self.username}', "
+ f"database_name='{self.database_name}')"
+ )
@property
def url(self):
+ """Provide a URL to the PostgreSQL database.
+
+ The URL follows the official PostgreSQL specification.
+ (See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for more information)
+ """
return f"postgresql://{self.username}:{self.password}" f"@{self.host}:{self.port}/{self.database_name}"
@dataclasses.dataclass
-class S3Connection:
+class S3Connection(Connection):
+ """AWS S3 connection.
+
+ See https://docs.aws.amazon.com/s3/ for more information.
+ """
+
access_key_id: str
secret_access_key: str
bucket_name: str
def __repr__(self):
+ """Safe representation of the S3 connection (no credentials)."""
return f"S3Connection(bucket_name='{self.bucket_name}')"
@dataclasses.dataclass
-class GCSConnection:
+class GCSConnection(Connection):
+ """Google Cloud Storage connection.
+
+ See https://cloud.google.com/storage/docs for more information.
+ """
+
service_account_key: str
bucket_name: str
def __repr__(self):
+ """Safe representation of the GCS connection (no credentials)."""
return f"GCSConnection(bucket_name='{self.bucket_name}')"
+@dataclasses.dataclass
+class CustomConnection(Connection):
+ """Marker class for custom connections.
+
+ The actual connection class is built dynamically by the workspace.custom_connection() method.
+ """
+
+ def __repr__(self):
+ """Safe representation of the custom connection (no credentials)."""
+ return f"CustomConnection(name='{self.__class__.__name__.lower()}')"
+
+
@dataclasses.dataclass
class IASOConnection:
+ """IASO connection.
+
+ See https://github.com/BLSQ/iaso for more information.
+ """
+
url: str
username: str
password: str
def __repr__(self):
+ """Safe representation of the IASO connection (no credentials)."""
return f"IASOConnection(url='{self.url}', username='{self.username}')"
diff --git a/openhexa/sdk/workspaces/workspace.py b/openhexa/sdk/workspaces/workspace.py
index 743ec5c..c5ee69f 100644
--- a/openhexa/sdk/workspaces/workspace.py
+++ b/openhexa/sdk/workspaces/workspace.py
@@ -1,3 +1,8 @@
+"""Workspace-related classes and functions.
+
+See https://github.com/BLSQ/openhexa/wiki/User-manual#about-workspaces for more information.
+"""
+
import os
from dataclasses import make_dataclass
from warnings import warn
@@ -7,6 +12,7 @@
from ..datasets import Dataset
from ..utils import graphql
from .connection import (
+ CustomConnection,
DHIS2Connection,
GCSConnection,
IASOConnection,
@@ -16,30 +22,41 @@
class WorkspaceConfigError(Exception):
+ """Raised whenever the system cannot find an environment variable required to configure the current workspace."""
+
pass
class ConnectionDoesNotExist(Exception):
+ """Raised whenever an attempt is made to get a connection through an invalid identifier."""
+
pass
class CurrentWorkspace:
+ """Represents the currently configured OpenHEXA workspace, with its filesystem, database and connections."""
+
@property
- def _token(self):
+ def _token(self) -> str:
try:
return os.environ["HEXA_TOKEN"]
except KeyError:
- raise WorkspaceConfigError("Workspace's token is not available in this environment.")
+ raise WorkspaceConfigError("The workspace token is not available in this environment.")
@property
- def slug(self):
+ def slug(self) -> str:
+ """The unique slug of the workspace.
+
+ The slug is used to identify the workspace when interacting with the OpenHEXA backend.
+ """
try:
return os.environ["HEXA_WORKSPACE"]
except KeyError:
- raise WorkspaceConfigError("Workspace's slug is not available in this environment.")
+ raise WorkspaceConfigError("The workspace slug is not available in this environment.")
@property
- def database_host(self):
+ def database_host(self) -> str:
+ """The workspace database host."""
try:
return os.environ["WORKSPACE_DATABASE_HOST"]
except KeyError:
@@ -49,7 +66,8 @@ def database_host(self):
)
@property
- def database_username(self):
+ def database_username(self) -> str:
+ """The workspace database username."""
try:
return os.environ["WORKSPACE_DATABASE_USERNAME"]
except KeyError:
@@ -60,6 +78,7 @@ def database_username(self):
@property
def database_password(self):
+ """The workspace database password."""
try:
return os.environ.get("WORKSPACE_DATABASE_PASSWORD")
except KeyError:
@@ -70,6 +89,7 @@ def database_password(self):
@property
def database_name(self):
+ """The workspace database name."""
try:
return os.environ["WORKSPACE_DATABASE_DB_NAME"]
except KeyError:
@@ -80,6 +100,7 @@ def database_name(self):
@property
def database_port(self):
+ """The workspace database port."""
try:
return int(os.environ["WORKSPACE_DATABASE_PORT"])
except KeyError:
@@ -90,6 +111,11 @@ def database_port(self):
@property
def database_url(self):
+ """The workspace database URL.
+
+ The URL follows the official PostgreSQL specification.
+ (See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for more information)
+ """
return (
f"postgresql://{self.database_username}:{self.database_password}"
f"@{self.database_host}:{self.database_port}/{self.database_name}"
@@ -97,21 +123,41 @@ def database_url(self):
@property
def files_path(self) -> str:
- """Return the base path to the filesystem, without trailing slash"""
+ """The base path to the filesystem, without trailing slash.
+ Examples
+ --------
+ >>> f"{workspace.files_path}/some/path"
+ /home/hexa/workspace/some/path
+ """
# FIXME: This is a hack to make the SDK work in the context of the `python pipeline.py` command.
# We can remove this once we deprecate this way of running pipelines
return os.environ["WORKSPACE_FILES_PATH"] if "WORKSPACE_FILES_PATH" in os.environ else "/home/hexa/workspace"
@property
def tmp_path(self) -> str:
- """Return the base path to the tmp directory, without trailing slash"""
+ """The base path to the tmp directory, without trailing slash.
+ Examples
+ --------
+ >>> f"{workspace.tmp_path}/some/path"
+ /home/hexa/tmp/some/path
+ """
# FIXME: This is a hack to make the SDK work in the context of the `python pipeline.py` command.
# We can remove this once we deprecate this way of running pipelines
return os.environ["WORKSPACE_TMP_PATH"] if "WORKSPACE_TMP_PATH" in os.environ else "/home/hexa/tmp"
- def dhis2_connection(self, identifier: str = None, slug: str = None) -> DHIS2Connection:
+ @staticmethod
+ def dhis2_connection(identifier: str = None, slug: str = None) -> DHIS2Connection:
+ """Get a DHIS2 connection by its identifier.
+
+ Parameters
+ ----------
+ identifier : str
+ The identifier of the connection in the OpenHEXA backend
+ slug : str
+ Deprecated, same as identifier
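+
+ Examples
+ --------
+ >>> # Illustrative sketch: "my-dhis2" stands for a connection identifier defined in the workspace
+ >>> con = workspace.dhis2_connection("my-dhis2")
+ >>> con.url, con.username, con.password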
+ """
identifier = identifier or slug
if slug is not None:
warn("'slug' is deprecated. Use 'identifier' instead.", DeprecationWarning, stacklevel=2)
@@ -125,7 +171,17 @@ def dhis2_connection(self, identifier: str = None, slug: str = None) -> DHIS2Con
return DHIS2Connection(url=url, username=username, password=password)
- def postgresql_connection(self, identifier: str = None, slug: str = None) -> PostgreSQLConnection:
+ @staticmethod
+ def postgresql_connection(identifier: str = None, slug: str = None) -> PostgreSQLConnection:
+ """Get a PostgreSQL connection by its identifier.
+
+ Parameters
+ ----------
+ identifier : str
+ The identifier of the connection in the OpenHEXA backend
+ slug : str
+ Deprecated, same as identifier
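+
+ Examples
+ --------
+ >>> # Illustrative sketch: "my-db" stands for a connection identifier defined in the workspace
+ >>> con = workspace.postgresql_connection("my-db")
+ >>> engine = create_engine(con.url)  # e.g. with SQLAlchemy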
+ """
identifier = identifier or slug
if slug is not None:
warn("'slug' is deprecated. Use 'identifier' instead.", DeprecationWarning, stacklevel=2)
@@ -147,7 +203,17 @@ def postgresql_connection(self, identifier: str = None, slug: str = None) -> Pos
database_name=dbname,
)
- def s3_connection(self, identifier: str = None, slug: str = None) -> S3Connection:
+ @staticmethod
+ def s3_connection(identifier: str = None, slug: str = None) -> S3Connection:
+ """Get an AWS S3 connection by its identifier.
+
+ Parameters
+ ----------
+ identifier : str
+ The identifier of the connection in the OpenHEXA backend
+ slug : str
+ Deprecated, same as identifier
+ """
identifier = identifier or slug
if slug is not None:
warn("'slug' is deprecated. Use 'identifier' instead.", DeprecationWarning, stacklevel=2)
@@ -165,7 +231,17 @@ def s3_connection(self, identifier: str = None, slug: str = None) -> S3Connectio
bucket_name=bucket_name,
)
- def gcs_connection(self, identifier: str = None, slug: str = None) -> GCSConnection:
+ @staticmethod
+ def gcs_connection(identifier: str = None, slug: str = None) -> GCSConnection:
+ """Get a Google Cloud Storage connection by its identifier.
+
+ Parameters
+ ----------
+ identifier : str
+ The identifier of the connection in the OpenHEXA backend
+ slug : str
+ Deprecated, same as identifier
+ """
identifier = identifier or slug
if slug is not None:
warn("'slug' is deprecated. Use 'identifier' instead.", DeprecationWarning, stacklevel=2)
@@ -181,7 +257,17 @@ def gcs_connection(self, identifier: str = None, slug: str = None) -> GCSConnect
bucket_name=bucket_name,
)
- def iaso_connection(self, identifier: str = None, slug: str = None) -> IASOConnection:
+ @staticmethod
+ def iaso_connection(identifier: str = None, slug: str = None) -> IASOConnection:
+ """Get a IASO connection by it identifier.
+
+ Parameters
+ ----------
+ identifier : str
+ The identifier of the connection in the OpenHEXA backend
+ slug : str
+ Deprecated, same as identifier
+ """
identifier = identifier or slug
if slug is not None:
warn("'slug' is deprecated. Use 'identifier' instead.", DeprecationWarning, stacklevel=2)
@@ -195,30 +281,48 @@ def iaso_connection(self, identifier: str = None, slug: str = None) -> IASOConne
return IASOConnection(url=url, username=username, password=password)
- def custom_connection(self, identifier: str = None, slug: str = None):
+ @staticmethod
+ def custom_connection(identifier: str = None, slug: str = None) -> CustomConnection:
+ """Get a custom connection by its identifier.
+
+ Parameters
+ ----------
+ identifier : str
+ The identifier of the connection in the OpenHEXA backend
+ slug : str
+ Deprecated, same as identifier
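+
+ Examples
+ --------
+ >>> # Illustrative sketch: assuming MY_API_URL and MY_API_KEY environment variables
+ >>> # were provisioned for a custom connection with identifier "my-api"
+ >>> con = workspace.custom_connection("my-api")
+ >>> con.url, con.key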
+ """
identifier = identifier or slug
if slug is not None:
warn("'slug' is deprecated. Use 'identifier' instead.", DeprecationWarning, stacklevel=2)
- identifier = identifier.lower()
- env_variable_prefix = stringcase.constcase(identifier)
+ env_variable_prefix = stringcase.constcase(identifier.lower())
fields = {}
for key, value in os.environ.items():
if key.startswith(env_variable_prefix):
field_key = key[len(f"{env_variable_prefix}_") :].lower()
fields[field_key] = value
- dataclass = make_dataclass(identifier, fields.keys())
+ if len(fields) == 0:
+ raise ConnectionDoesNotExist(f'No custom connection for "{identifier}"')
- class CustomConnection(dataclass):
- def __repr__(self):
- return f"CustomConnection(name='{identifier}')"
+ dataclass = make_dataclass(
+ stringcase.pascalcase(identifier), fields.keys(), bases=(CustomConnection,), repr=False
+ )
- return CustomConnection(**fields)
+ return dataclass(**fields)
def create_dataset(self, identifier: str, name: str, description: str):
+ """Create a new dataset."""
raise NotImplementedError("create_dataset is not implemented yet.")
- def get_dataset(self, identifier: str):
+ def get_dataset(self, identifier: str) -> Dataset:
+ """Get a dataset by its identifier.
+
+ Parameters
+ ----------
+ identifier : str
+ The identifier of the dataset in the OpenHEXA backend
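+
+ Examples
+ --------
+ >>> # Illustrative sketch: "my-dataset" stands for a dataset slug in the workspace
+ >>> dataset = workspace.get_dataset("my-dataset")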
+ """
response = graphql(
"""
query getDataset($datasetSlug: String!, $workspaceSlug: String!) {
diff --git a/pyproject.toml b/pyproject.toml
index b5cbb60..c11a9ea 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name= "openhexa.sdk"
version = "0.1.33"
-description = "OpenHexa SDK"
+description = "OpenHEXA SDK"
authors = [
{ name = "Bluesquare", email = "dev@bluesquarehub.com"}
@@ -35,7 +35,7 @@ openhexa = "openhexa.cli:app"
[project.optional-dependencies]
-dev = [ "ruff~=0.0.278", "pytest~=7.3.0", "build>=0.10,<1.1", "pytest-cov>=4.0,<4.2" , "black~=23.7.0", "pre-commit"]
+dev = [ "ruff~=0.0.278", "pytest~=7.3.0", "build>=0.10,<1.1", "pytest-cov>=4.0,<4.2" , "pre-commit"]
examples= ["geopandas~=0.12.2", "pandas~=2.0.0", "rasterio~=1.3.6", "rasterstats>=0.18,<0.20", "setuptools>=67.6.1,<69.1.0", "SQLAlchemy~=2.0.9", "psycopg2"]
@@ -53,26 +53,29 @@ namespaces = true
[tool.setuptools]
include-package-data = true
-[tool.black]
-line-length = 120
-exclude = '''
-(
- /(
- \.eggs # exclude a few common directories in the
- | \.git # root of the project
- )/
- | node_modules
-)
-'''
-
[tool.ruff]
line-length = 120
ignore = ["E501"]
+[tool.ruff.lint]
+extend-select = [
+ "D", # pydocstyle
+ "I", # isort
+ "UP", # pyupgrade
+ "ANN", # flake8-annotations
+]
+# Disable all "missing docstrings" and "missing type annotations" rules for now
+# TODO: enable
+ignore = ["ANN001", "ANN002", "ANN003", "ANN101", "ANN102", "ANN201", "ANN202", "ANN204", "ANN205", "ANN401"]
+
[tool.ruff.pycodestyle]
max-doc-length = 120
+[tool.ruff.lint.pydocstyle]
+convention = "numpy" # Accepts: "google", "numpy", or "pep257".
+
+
[tool.coverage.report]
exclude_lines = [
# Have to re-enable the standard pragma
diff --git a/tests/__init__.py b/tests/__init__.py
index e69de29..38bb211 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -0,0 +1 @@
+"""Test package."""
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..e86c6c3
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,26 @@
+"""Module-level test fixtures."""
+from importlib import reload
+
+import pytest
+
+import openhexa.sdk
+
+
+@pytest.fixture(scope="function")
+def workspace():
+ """Build workspace fixture."""
+ from openhexa.sdk import workspace as global_workspace
+
+ reload(openhexa.sdk)
+
+ return global_workspace
+
+
+@pytest.fixture(scope="function")
+def current_run():
+ """Build current run fixture."""
+ from openhexa.sdk import current_run as global_current_run
+
+ reload(openhexa.sdk)
+
+ return global_current_run
diff --git a/tests/test_api.py b/tests/test_api.py
index a7913c6..d3c2335 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -1,26 +1,29 @@
+"""API interactions test module."""
+
import configparser
import os
-import pytest
import shutil
-import yaml
import uuid
-
-from click.testing import CliRunner
-from openhexa.cli.api import upload_pipeline
-from openhexa.cli.cli import pipelines_init, pipelines_delete
from pathlib import Path
-
from unittest import mock
from zipfile import ZipFile
+import pytest
+import yaml
+from click.testing import CliRunner
+
+from openhexa.cli.api import upload_pipeline
+from openhexa.cli.cli import pipelines_delete, pipelines_init
+
def test_upload_pipeline():
+ """Test upload API call."""
# to enable zip file creation
config = configparser.ConfigParser()
config["openhexa"] = {"debug": True, "current_workspace": "test_workspace"}
runner = CliRunner()
- runner.invoke(pipelines_init, ["test_pipelines"])
+ runner.invoke(pipelines_init, ["test_pipelines"]) # NOQA
pipeline_dir = Path.cwd() / "test_pipelines"
pipeline_zip_file_dir = Path.cwd() / "pipeline.zip"
@@ -41,17 +44,18 @@ def test_upload_pipeline():
def test_upload_pipeline_custom_files_path():
+ """Test upload API call (custom file path)."""
# to enable zip file creation
config = configparser.ConfigParser()
config["openhexa"] = {"debug": True, "current_workspace": "test_workspace"}
runner = CliRunner()
- runner.invoke(pipelines_init, ["test_pipelines"])
+ runner.invoke(pipelines_init, ["test_pipelines"]) # NOQA
pipeline_dir = Path.cwd() / "test_pipelines"
pipeline_zip_file_dir = Path.cwd() / "pipeline.zip"
(pipeline_dir / Path("data")).mkdir()
- # setup a custom path for files location in workspace.yaml
+ # set up a custom path for files location in workspace.yaml
pipeline_configs = {"files": {"path": "./data"}}
with open(pipeline_dir / Path("workspace.yaml"), "w") as pipeline_configs_file:
@@ -74,6 +78,7 @@ def test_upload_pipeline_custom_files_path():
def test_delete_pipeline_not_in_workspace():
+ """Test delete pipeline (pipeline does not exist)."""
config = configparser.ConfigParser()
config["openhexa"] = {"debug": True, "current_workspace": "test_workspace"}
@@ -83,12 +88,13 @@ def test_delete_pipeline_not_in_workspace():
runner = CliRunner()
mocked_config.return_value = config
mocked_graphql_client.return_value = {"pipelineByCode": None}
- r = runner.invoke(pipelines_delete, ["test_pipelines"], input="test_pipelines")
+ r = runner.invoke(pipelines_delete, ["test_pipelines"], input="test_pipelines") # NOQA
assert r.output == "Pipeline test_pipelines does not exist in workspace test_workspace\n"
def test_delete_pipeline_confirm_code_invalid():
+ """Test delete pipeline with an invalid confirmation code."""
config = configparser.ConfigParser()
config["openhexa"] = {"debug": True, "current_workspace": "test_workspace"}
@@ -105,12 +111,13 @@ def test_delete_pipeline_confirm_code_invalid():
"currentVersion": {"number": 1},
}
}
- r = runner.invoke(pipelines_delete, ["test_pipelines"], input="test_pipeline")
+ r = runner.invoke(pipelines_delete, ["test_pipelines"], input="test_pipeline") # NOQA
# "Pipeline code and confirmation are different
assert r.exit_code == 1
def test_delete_pipeline():
+ """Happy path for delete pipeline API call."""
config = configparser.ConfigParser()
config["openhexa"] = {"debug": True, "current_workspace": "test_workspace"}
@@ -126,6 +133,6 @@ def test_delete_pipeline():
"currentVersion": {"number": 1},
}
mocked_delete_pipeline.return_value = True
- r = runner.invoke(pipelines_delete, ["test_pipelines"], input="test_pipelines")
+ r = runner.invoke(pipelines_delete, ["test_pipelines"], input="test_pipelines") # NOQA
assert r.exit_code == 0
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 4878b01..f913fff 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -1,3 +1,5 @@
+"""Dataset test module."""
+
import os
from unittest import TestCase
from unittest.mock import patch
@@ -6,12 +8,15 @@
class DatasetTest(TestCase):
+ """Dataset test class."""
+
@patch.dict(
os.environ,
{"HEXA_WORKSPACE": "workspace-slug", "HEXA_TOKEN": "token", "HEXA_SERVER_URL": "server"},
)
@patch("openhexa.sdk.datasets.dataset.graphql")
def test_create_dataset_version(self, mock_graphql):
+ """Ensure that dataset versions can be created."""
d = Dataset(
id="id",
slug="my-dataset",
diff --git a/tests/test_environment.py b/tests/test_environment.py
index 991f752..ef00060 100644
--- a/tests/test_environment.py
+++ b/tests/test_environment.py
@@ -1,3 +1,5 @@
+"""Collection of tests related to environment variables handling."""
+
import os
from unittest.mock import Mock, patch
@@ -5,14 +7,14 @@
@patch.dict(os.environ, {"WORKSPACE_FILES_PATH": "/workspace/path"})
-def test_workspace_local_files_path():
- from openhexa.sdk.workspaces import workspace
-
+def test_workspace_local_files_path(workspace):
+ """Test workspace files_path property with explicit environment variable."""
assert workspace.files_path == "/workspace/path"
@patch.dict(os.environ, {"HEXA_ENVIRONMENT": "LOCAL_PIPELINE"})
-def test_workspace_pipeline_files_path():
+def test_workspace_pipeline_files_path(workspace):
+ """Test workspace files_path property in local mode."""
- from openhexa.sdk.workspaces import workspace
assert workspace.files_path == "/home/hexa/workspace"
@@ -22,21 +24,19 @@ def test_workspace_pipeline_files_path():
os.environ,
{"HEXA_ENVIRONMENT": "CLOUD_PIPELINE", "HEXA_SERVER_URL": "https://test.openhexa.org"},
)
-def test_connected():
- from openhexa.sdk.pipelines import current_run
-
+def test_connected(current_run):
+ """Test run _connected property in CLOUD_PIPELINE mode."""
assert current_run._connected is True
@patch.dict(
os.environ,
{
- "HEXA_ENVIRONMENT": "LOCAL_PIPELINE",
+ "HEXA_ENVIRONMENT": "CLOUD_PIPELINE",
},
)
-def test_not_connected():
- from openhexa.sdk.pipelines import current_run
-
+def test_not_connected_missing_url(current_run):
+ """Test run _connected property in CLOUD_PIPELINE mode (missing HEXA_SERVER_URL)."""
assert current_run._connected is False
@@ -46,9 +46,8 @@ def test_not_connected():
"HEXA_ENVIRONMENT": "LOCAL_PIPELINE",
},
)
-def test_not_connected_missing_url():
- from openhexa.sdk.pipelines import current_run
-
+def test_not_connected(current_run):
+ """Test run _connected property in LOCAL_PIPELINE mode."""
assert current_run._connected is False
@@ -59,6 +58,7 @@ def test_not_connected_missing_url():
},
)
def test_pipeline_local():
+ """Test pipeline _connected property in LOCAL_PIPELINE mode."""
pipeline_func = Mock()
pipeline = Pipeline("code", "pipeline", pipeline_func, [])
@@ -73,6 +73,7 @@ def test_pipeline_local():
},
)
def test_pipeline_pipeline():
+ """Test pipeline _connected property in CLOUD_PIPELINE mode."""
pipeline_func = Mock()
pipeline = Pipeline("code", "pipeline", pipeline_func, [])
diff --git a/tests/test_parameter.py b/tests/test_parameter.py
index 4f6c99f..94f05fd 100644
--- a/tests/test_parameter.py
+++ b/tests/test_parameter.py
@@ -1,36 +1,38 @@
-import pytest
+"""Parameter test module."""
+
import os
+from unittest import mock
+
+import pytest
import stringcase
from openhexa.sdk import (
DHIS2Connection,
+ GCSConnection,
IASOConnection,
PostgreSQLConnection,
- GCSConnection,
S3Connection,
)
-
from openhexa.sdk.pipelines.parameter import (
Boolean,
+ DHIS2ConnectionType,
Float,
FunctionWithParameter,
+ GCSConnectionType,
+ IASOConnectionType,
Integer,
InvalidParameterError,
Parameter,
ParameterValueError,
- StringType,
PostgreSQLConnectionType,
- GCSConnectionType,
S3ConnectionType,
- IASOConnectionType,
- DHIS2ConnectionType,
+ StringType,
parameter,
)
-from unittest import mock
-
def test_parameter_types_normalize():
+ """Check normalization or basic types."""
# StringType
string_parameter_type = StringType()
assert string_parameter_type.normalize("a string") == "a string"
@@ -56,6 +58,7 @@ def test_parameter_types_normalize():
def test_parameter_types_validate():
+ """Sanity checks for basic types validation."""
# StringType
string_parameter_type = StringType()
assert string_parameter_type.validate("a string") == "a string"
@@ -83,6 +86,7 @@ def test_parameter_types_validate():
def test_validate_postgres_connection():
+ """Check PostgreSQL connection validation."""
identifier = "polio-ff3a0d"
env_variable_prefix = stringcase.constcase(identifier)
host = "https://172.17.0.1"
@@ -109,6 +113,7 @@ def test_validate_postgres_connection():
def test_validate_dhis2_connection():
+ """Check DHIS2 connection validation."""
identifier = "dhis2-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
url = "https://test.dhis2.org/"
@@ -123,13 +128,14 @@ def test_validate_dhis2_connection():
f"{env_variable_prefix}_PASSWORD": password,
},
):
- dhsi2_parameter_type = DHIS2ConnectionType()
- assert dhsi2_parameter_type.validate(identifier) == DHIS2Connection(url, username, password)
+ dhis2_parameter_type = DHIS2ConnectionType()
+ assert dhis2_parameter_type.validate(identifier) == DHIS2Connection(url, username, password)
with pytest.raises(ParameterValueError):
- dhsi2_parameter_type.validate(86)
+ dhis2_parameter_type.validate(86)
def test_validate_iaso_connection():
+ """Check IASO connection validation."""
identifier = "iaso-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
url = "https://test.iaso.org/"
@@ -151,6 +157,7 @@ def test_validate_iaso_connection():
def test_validate_gcs_connection():
+ """Check GCS connection validation."""
identifier = "gcs-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
service_account_key = "HqQBxH0BAI3zF7kANUNlGg"
@@ -170,6 +177,7 @@ def test_validate_gcs_connection():
def test_validate_s3_connection():
+ """Check S3 connection validation."""
identifier = "s3-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
secret_access_key = "HqQBxH0BAI3zF7kANUNlGg"
@@ -191,9 +199,10 @@ def test_validate_s3_connection():
def test_parameter_init():
+ """Sanity checks for parameter initialization."""
# Wrong type
with pytest.raises(InvalidParameterError):
- Parameter("arg", type="string")
+ Parameter("arg", type="string") # NOQA
# Wrong code
with pytest.raises(InvalidParameterError):
@@ -233,6 +242,7 @@ def test_parameter_init():
def test_parameter_validate_single():
+ """Base check for single-value validation."""
# required is True by default
parameter_1 = Parameter("arg1", type=str)
assert parameter_1.validate("a valid string") == "a valid string"
@@ -257,6 +267,7 @@ def test_parameter_validate_single():
def test_parameter_validate_multiple():
+ """Test multiple values validation rules."""
# required is True by default
parameter_1 = Parameter("arg1", type=str, multiple=True)
assert parameter_1.validate(["Valid string", "Another valid string"]) == [
@@ -290,6 +301,7 @@ def test_parameter_validate_multiple():
def test_parameter_parameters_spec():
+ """Verify that parameter specifications are built properly and have the proper defaults."""
# required is True by default
an_parameter = Parameter("arg1", type=str, default="yep")
another_parameter = Parameter(
@@ -326,6 +338,8 @@ def test_parameter_parameters_spec():
def test_parameter_decorator():
+ """Ensure that the @parameter decorator behaves as expected (options and defaults)."""
+
@parameter("arg1", type=int)
@parameter(
"arg2",
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index b82aeb7..2dd16dc 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -1,14 +1,16 @@
+"""Pipeline test module."""
+
+import os
from unittest.mock import Mock, patch
import pytest
import stringcase
-import os
from openhexa.sdk import (
DHIS2Connection,
+ GCSConnection,
IASOConnection,
PostgreSQLConnection,
- GCSConnection,
S3Connection,
)
from openhexa.sdk.pipelines.parameter import Parameter, ParameterValueError
@@ -16,6 +18,7 @@
def test_pipeline_run_valid_config():
+ """Happy path for pipeline run config."""
pipeline_func = Mock()
parameter_1 = Parameter("arg1", type=str)
parameter_2 = Parameter("arg2", type=str, multiple=True)
@@ -28,6 +31,7 @@ def test_pipeline_run_valid_config():
def test_pipeline_run_invalid_config():
+ """Verify thatinvalid configuration values raise an exception."""
pipeline_func = Mock()
parameter_1 = Parameter("arg1", type=str)
pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1])
@@ -36,6 +40,7 @@ def test_pipeline_run_invalid_config():
def test_pipeline_run_extra_config():
+ """Verify that extra (unexpected) configuration values raise an exception."""
pipeline_func = Mock()
parameter_1 = Parameter("arg1", type=str)
pipeline = Pipeline("code", "pipeline", pipeline_func, [parameter_1])
@@ -44,6 +49,7 @@ def test_pipeline_run_extra_config():
def test_pipeline_run_connection_dhis2_parameter_config():
+ """Ensure that DHIS2 connection parameter values are built properly."""
identifier = "dhis2-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
url = "https://test.dhis2.org/"
@@ -69,6 +75,7 @@ def test_pipeline_run_connection_dhis2_parameter_config():
def test_pipeline_run_connection_iaso_parameter_config():
+ """Ensure that IASO connection parameter values are built properly."""
identifier = "iaso-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
url = "https://test.iaso.org/"
@@ -93,6 +100,7 @@ def test_pipeline_run_connection_iaso_parameter_config():
def test_pipeline_run_connection_gcs_parameter_config():
+ """Ensure that GCS connection parameter values are built properly."""
identifier = "gcs-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
service_account_key = "HqQBxH0BAI3zF7kANUNlGg"
@@ -115,6 +123,7 @@ def test_pipeline_run_connection_gcs_parameter_config():
def test_pipeline_run_connection_s3_parameter_config():
+ """Ensure that S3 connection parameter values are built properly."""
identifier = "s3-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
secret_access_key = "HqQBxH0BAI3zF7kANUNlGg"
@@ -141,6 +150,7 @@ def test_pipeline_run_connection_s3_parameter_config():
def test_pipeline_run_connection_postgres_parameter_config():
+ """Ensure that postgreSQL connection parameter values are built properly."""
identifier = "postgres-connection-id"
env_variable_prefix = stringcase.constcase(identifier)
host = "https://127.0.0.1"
@@ -173,6 +183,7 @@ def test_pipeline_run_connection_postgres_parameter_config():
def test_pipeline_parameters_spec():
+ """Base checks for parameter specs building."""
pipeline_func = Mock()
parameter_1 = Parameter("arg1", type=str)
parameter_2 = Parameter("arg2", type=str, multiple=True)
diff --git a/tests/test_workspace.py b/tests/test_workspace.py
index 154fcbd..41cd7b5 100644
--- a/tests/test_workspace.py
+++ b/tests/test_workspace.py
@@ -1,22 +1,26 @@
+"""Workspace test module."""
+
import os
+import re
from tempfile import mkdtemp
+from unittest import mock
import pytest
-import re
import stringcase
-from openhexa.sdk.workspaces.workspace import ConnectionDoesNotExist, workspace
-from unittest import mock
+from openhexa.sdk.workspaces.workspace import ConnectionDoesNotExist
-def test_workspace_files_path(monkeypatch):
+def test_workspace_files_path(monkeypatch, workspace):
+ """Basic checks for the Workspace.files_path() method."""
assert workspace.files_path == "/home/hexa/workspace"
monkeypatch.setenv("WORKSPACE_FILES_PATH", "/Users/John/openhexa/project-1/workspace")
assert workspace.files_path == "/Users/John/openhexa/project-1/workspace"
-def test_workspace_tmp_path(monkeypatch):
+def test_workspace_tmp_path(monkeypatch, workspace):
+ """Basic checks for the Workspace.tmp_path() method."""
assert workspace.tmp_path == "/home/hexa/tmp"
mock_tmp_path = mkdtemp()
@@ -24,13 +28,15 @@ def test_workspace_tmp_path(monkeypatch):
assert workspace.tmp_path == mock_tmp_path
-def test_workspace_dhis2_connection_not_exist():
+def test_workspace_dhis2_connection_not_exist(workspace):
+ """Does not exist test case for DHIS2 connections."""
identifier = "polio-ff3a0d"
with pytest.raises(ConnectionDoesNotExist):
workspace.dhis2_connection(identifier=identifier)
-def test_workspace_dhis2_connection():
+def test_workspace_dhis2_connection(workspace):
+ """Base test case for DHIS2 connections."""
identifier = "polio-ff3a0d"
env_variable_prefix = stringcase.constcase(identifier)
url = "https://test.dhis2.org/"
@@ -48,17 +54,19 @@ def test_workspace_dhis2_connection():
assert dhis2_connection.url == url
assert dhis2_connection.username == username
assert dhis2_connection.password == password
- assert re.search("password", repr(dhis2_connection)) is None
- assert re.search("password", str(dhis2_connection)) is None
+ assert re.search("password", repr(dhis2_connection), re.IGNORECASE) is None
+ assert re.search("password", str(dhis2_connection), re.IGNORECASE) is None
-def test_workspace_postgresql_connection_not_exist():
+def test_workspace_postgresql_connection_not_exist(workspace):
+ """Does not exist test case for PostgreSQL connections."""
identifier = "polio-ff3a0d"
with pytest.raises(ConnectionDoesNotExist):
workspace.postgresql_connection(identifier=identifier)
-def test_workspace_postgresql_connection():
+def test_workspace_postgresql_connection(workspace):
+ """Base test case for PostgreSQL connections."""
identifier = "polio-ff3a0d"
env_variable_prefix = stringcase.constcase(identifier)
host = "https://172.17.0.1"
@@ -84,17 +92,19 @@ def test_workspace_postgresql_connection():
assert postgres_connection.port == int(port)
assert postgres_connection.database_name == database_name
assert postgres_connection.url == url
- assert re.search("password", repr(postgres_connection)) is None
- assert re.search("password", str(postgres_connection)) is None
+ assert re.search("password", repr(postgres_connection), re.IGNORECASE) is None
+ assert re.search("password", str(postgres_connection), re.IGNORECASE) is None
-def test_workspace_S3_connection_not_exist():
+def test_workspace_S3_connection_not_exist(workspace):
+ """Does not exist test case for S3 connections."""
identifier = "polio-ff3a0d"
with pytest.raises(ConnectionDoesNotExist):
workspace.s3_connection(identifier=identifier)
-def test_workspace_s3_connection():
+def test_workspace_s3_connection(workspace):
+ """Base test case for S3 connections."""
identifier = "polio-ff3a0d"
env_variable_prefix = stringcase.constcase(identifier)
secret_access_key = "HqQBxH0BAI3zF7kANUNlGg"
@@ -113,19 +123,21 @@ def test_workspace_s3_connection():
assert s3_connection.secret_access_key == secret_access_key
assert s3_connection.access_key_id == access_key_id
assert s3_connection.bucket_name == bucket_name
- assert re.search("secret_access_key", repr(s3_connection)) is None
- assert re.search("secret_access_key", str(s3_connection)) is None
- assert re.search("access_key_id", repr(s3_connection)) is None
- assert re.search("access_key_id", str(s3_connection)) is None
+ assert re.search("secret_access_key", repr(s3_connection), re.IGNORECASE) is None
+ assert re.search("secret_access_key", str(s3_connection), re.IGNORECASE) is None
+ assert re.search("access_key_id", repr(s3_connection), re.IGNORECASE) is None
+ assert re.search("access_key_id", str(s3_connection), re.IGNORECASE) is None
-def test_workspace_gcs_connection_not_exist():
+def test_workspace_gcs_connection_not_exist(workspace):
+ """Does not exist test case for GCS connections."""
identifier = "polio-ff3a0d"
with pytest.raises(ConnectionDoesNotExist):
workspace.gcs_connection(identifier=identifier)
-def test_workspace_gcs_connection():
+def test_workspace_gcs_connection(workspace):
+ """Base test case for GCS connections."""
identifier = "polio-ff3a0d"
env_variable_prefix = stringcase.constcase(identifier)
service_account_key = "HqQBxH0BAI3zF7kANUNlGg"
@@ -141,17 +153,19 @@ def test_workspace_gcs_connection():
s3_connection = workspace.gcs_connection(identifier=identifier)
assert s3_connection.service_account_key == service_account_key
assert s3_connection.bucket_name == bucket_name
- assert re.search("service_account_key", repr(s3_connection)) is None
- assert re.search("service_account_key", str(s3_connection)) is None
+ assert re.search("service_account_key", repr(s3_connection), re.IGNORECASE) is None
+ assert re.search("service_account_key", str(s3_connection), re.IGNORECASE) is None
-def test_workspace_iaso_connection_not_exist():
+def test_workspace_iaso_connection_not_exist(workspace):
+ """Does not exist test case for IASO connections."""
identifier = "iaso-account"
with pytest.raises(ConnectionDoesNotExist):
workspace.iaso_connection(identifier=identifier)
-def test_workspace_iaso_connection():
+def test_workspace_iaso_connection(workspace):
+ """Base test case for IASO connections."""
identifier = "iaso-account"
env_variable_prefix = stringcase.constcase(identifier)
url = "https://test.iaso.org/"
@@ -170,11 +184,19 @@ def test_workspace_iaso_connection():
assert iaso_connection.url == url
assert iaso_connection.username == username
assert iaso_connection.password == password
- assert re.search("password", repr(iaso_connection)) is None
- assert re.search("password", str(iaso_connection)) is None
+ assert re.search("password", repr(iaso_connection), re.IGNORECASE) is None
+ assert re.search("password", str(iaso_connection), re.IGNORECASE) is None
+
+
+def test_workspace_custom_connection_not_exist(workspace):
+ """Does not exist test case for custom connections."""
+ identifier = "custom-stuff"
+ with pytest.raises(ConnectionDoesNotExist):
+ workspace.custom_connection(identifier=identifier)
-def test_workspace_custom_connection():
+def test_workspace_custom_connection(workspace):
+ """Base test case for custom connections."""
identifier = "my_connection"
env_variable_prefix = stringcase.constcase(identifier)
username = "kaggle_username"
@@ -190,11 +212,12 @@ def test_workspace_custom_connection():
custom_connection = workspace.custom_connection(identifier=identifier)
assert custom_connection.username == username
assert custom_connection.password == password
- assert re.search(f"{env_variable_prefix}_PASSWORD", repr(custom_connection)) is None
- assert re.search(f"{env_variable_prefix}_PASSWORD", str(custom_connection)) is None
+ assert re.search("username", repr(custom_connection), re.IGNORECASE) is None
+ assert re.search("password", str(custom_connection), re.IGNORECASE) is None
-def test_connection_by_slug_warning():
+def test_connection_by_slug_warning(workspace):
+ """Ensure that using the slug keyword argument when getting a connection generates a deprecation warning."""
identifier = "polio-ff3a0d"
env_variable_prefix = stringcase.constcase(identifier)
url = "https://test.dhis2.org/"
@@ -214,7 +237,8 @@ def test_connection_by_slug_warning():
assert workspace.dhis2_connection(identifier=identifier).url == url
-def test_connection_various_case():
+def test_connection_various_case(workspace):
+ """Ensure that identifiers used when getting connections are case-insensitive."""
env_variable_prefix = stringcase.constcase("polio-123")
url = "https://test.dhis2.org/"
username = "dhis2"