From 01e7c64bc44217356b6ec73e2ca4909c4693ad30 Mon Sep 17 00:00:00 2001 From: pvanliefland Date: Thu, 28 Dec 2023 11:16:02 +0100 Subject: [PATCH] doc: add docstrings for public methods --- openhexa/sdk/pipelines/run.py | 13 ++++++ openhexa/sdk/pipelines/task.py | 63 ++++++++++++++------------- openhexa/sdk/workspaces/connection.py | 5 +++ openhexa/sdk/workspaces/workspace.py | 1 + tests/test_dataset.py | 1 + 5 files changed, 53 insertions(+), 30 deletions(-) diff --git a/openhexa/sdk/pipelines/run.py b/openhexa/sdk/pipelines/run.py index 976ca43..7022d8e 100644 --- a/openhexa/sdk/pipelines/run.py +++ b/openhexa/sdk/pipelines/run.py @@ -25,6 +25,10 @@ def tmp_path(self): return Path("~/tmp/") def add_file_output(self, path: str): + """Record a run output for a file creation operation. + + This output will be visible in the web interface, on the pipeline run page. + """ stripped_path = path.replace(workspace.files_path, "") name = stripped_path.strip("/") if self._connected: @@ -45,6 +49,10 @@ def add_file_output(self, path: str): print(f"Sending output with path {stripped_path}") def add_database_output(self, table_name: str): + """Record a run output for a database operation. + + This output will be visible in the web interface, on the pipeline run page. + """ if self._connected: graphql( """ @@ -63,18 +71,23 @@ def add_database_output(self, table_name: str): print(f"Sending output with table_name {table_name}") def log_debug(self, message: str): + """Log a message with the DEBUG priority.""" self._log_message("DEBUG", message) def log_info(self, message: str): + """Log a message with the INFO priority.""" self._log_message("INFO", message) def log_warning(self, message: str): + """Log a message with the WARNING priority.""" self._log_message("WARNING", message) def log_error(self, message: str): + """Log a message with the ERROR priority.""" self._log_message("ERROR", message) def log_critical(self, message: str): + """Log a message with the CRITICAL priority.""" self._log_message("CRITICAL", message) def _log_message( diff --git a/openhexa/sdk/pipelines/task.py b/openhexa/sdk/pipelines/task.py index 37c6e09..820a80b 100644 --- a/openhexa/sdk/pipelines/task.py +++ b/openhexa/sdk/pipelines/task.py @@ -12,7 +12,10 @@ class TaskCom: - """Lightweight data transfer object allowing tasks to communicate.""" + """Lightweight data transfer object allowing tasks to communicate. + + TaskCom instances also allow us to build the pipeline dependency graph. + """ def __init__(self, task): self.result = task.result @@ -38,27 +41,11 @@ def __init__(self, function: typing.Callable): self.active = False self.pooled = False - def __call__(self, *task_args, **task_kwargs): - self.active = True # uncalled tasks will be skipped - # check that all inputs are tasks - self.task_args = task_args - self.task_kwargs = task_kwargs - return self + def is_ready(self) -> bool: + """Determine whether the task is ready to be run. - def __repr__(self): - return self.name - - def get_node_inputs(self): - inputs = [] - for a in self.task_args: - if issubclass(type(a), Task): - inputs.append(a) - for k, a in self.task_kwargs.items(): - if issubclass(type(a), Task): - inputs.append(a) - return inputs - - def is_ready(self): + This involves checking whether tasks higher up in the dependency graph have been executed. + """ if not self.active: return False @@ -71,24 +58,32 @@ def is_ready(self): return True if self.end_time is None else False - def get_tasks_ready(self): + def get_ready_tasks(self) -> list[Task]: + """Find and return all tasks that can be launched at this point in time.""" tasks = [] for a in self.task_args: if issubclass(type(a), Task): if a.is_ready(): tasks.append(a) else: - tasks += a.get_tasks_ready() + tasks += a.get_ready_tasks() for k, a in self.task_kwargs.items(): if issubclass(type(a), Task): if a.is_ready(): tasks.append(a) else: - tasks += a.get_tasks_ready() + tasks += a.get_ready_tasks() return list(set(tasks)) - def run(self): + def run(self) -> TaskCom: + """Run the task. + + Returns + ------- + TaskCom + A TaskCom instance which can in turn be passed to other tasks. + """ if self.end_time: # already executed, return previous result return self.result @@ -118,10 +113,17 @@ def run(self): # done! return TaskCom(self) - def stateless_run(self): - self.result = None - self.start_time, self.end_time = None, None - return self.run() + def __call__(self, *task_args, **task_kwargs): + """Wrap the task with args and kwargs and return it.""" + self.active = True # uncalled tasks will be skipped + # check that all inputs are tasks + self.task_args = task_args + self.task_kwargs = task_kwargs + + return self + + def __repr__(self): + return self.name class PipelineWithTask: @@ -135,7 +137,8 @@ def __init__( self.function = function self.pipeline = pipeline - def __call__(self, *task_args, **task_kwargs): + def __call__(self, *task_args, **task_kwargs) -> Task: + """Attach the new task to the decorated pipeline and return it.""" task = Task(self.function)(*task_args, **task_kwargs) self.pipeline.tasks.append(task) return task diff --git a/openhexa/sdk/workspaces/connection.py b/openhexa/sdk/workspaces/connection.py index 0d1a7d2..b7fc2a4 100644 --- a/openhexa/sdk/workspaces/connection.py +++ b/openhexa/sdk/workspaces/connection.py @@ -36,6 +36,11 @@ def __repr__(self): @property def url(self): + """Provide a URL to the PostgreSQL database. + + The URL follows the official PostgreSQL specification. + (See https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING for more information) + """ return f"postgresql://{self.username}:{self.password}" f"@{self.host}:{self.port}/{self.database_name}" diff --git a/openhexa/sdk/workspaces/workspace.py b/openhexa/sdk/workspaces/workspace.py index c1e0003..9b13db5 100644 --- a/openhexa/sdk/workspaces/workspace.py +++ b/openhexa/sdk/workspaces/workspace.py @@ -313,6 +313,7 @@ def __repr__(self): return CustomConnection(**fields) def create_dataset(self, identifier: str, name: str, description: str): + """Create a new dataset.""" raise NotImplementedError("create_dataset is not implemented yet.") def get_dataset(self, identifier: str) -> Dataset: diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 4a4eb26..f913fff 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -16,6 +16,7 @@ class DatasetTest(TestCase): ) @patch("openhexa.sdk.datasets.dataset.graphql") def test_create_dataset_version(self, mock_graphql): + """Ensure that dataset versions can be created.""" d = Dataset( id="id", slug="my-dataset",