-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8665d48
commit bbdbb97
Showing
8 changed files
with
576 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,3 +13,4 @@ def cli() -> None: | |
|
||
|
||
cli.add_command(commands.sessions) | ||
cli.add_command(commands.tasks) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from .sessions import sessions | ||
from .tasks import tasks | ||
|
||
|
||
__all__ = ["sessions"] | ||
__all__ = ["sessions", "tasks"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,258 @@ | ||
import grpc | ||
import rich_click as click | ||
|
||
from datetime import timedelta | ||
from typing import List, Tuple, Union | ||
|
||
from armonik.client.tasks import ArmoniKTasks | ||
from armonik.common import Task, TaskStatus, TaskDefinition, TaskOptions, Direction | ||
from armonik.common.filter import TaskFilter, Filter | ||
|
||
from armonik_cli.core import console, base_command | ||
from armonik_cli.core.params import KeyValuePairParam, TimeDeltaParam, FilterParam, FieldParam | ||
|
||
TASKS_TABLE_COLS = [("ID", "Id"), ("Status", "Status"), ("CreatedAt", "CreatedAt")] | ||
|
||
|
||
@click.group(name="task") | ||
def tasks() -> None: | ||
"""Manage cluster's tasks.""" | ||
pass | ||
|
||
|
||
@tasks.command() | ||
@click.argument("session-id", required=True, type=str) | ||
@click.option( | ||
"-f", | ||
"--filter", | ||
"filter_with", | ||
type=FilterParam("Task"), | ||
required=False, | ||
help="An expression to filter the listed tasks with.", | ||
metavar="FILTER EXPR", | ||
) | ||
@click.option( | ||
"--sort_by", type=FieldParam("Task"), required=False, help="Attribute of task to sort with." | ||
) | ||
@click.option( | ||
"--sort_order", | ||
type=click.Choice(["asc", "desc"], case_sensitive=False), | ||
default="asc", | ||
required=False, | ||
help="Whether to sort by ascending or by descending order.", | ||
) | ||
@click.option("--page", default=-1) | ||
@click.option("--page-size", default=100) | ||
@base_command | ||
def list( | ||
endpoint: str, | ||
output: str, | ||
session_id: str, | ||
filter_with: Union[TaskFilter, None], | ||
sort_by: Filter, | ||
sort_order: str, | ||
page: int, | ||
page_size: int, | ||
debug: bool, | ||
) -> None: | ||
"List all tasks in a session. This command takes the session id as an argument." | ||
with grpc.insecure_channel(endpoint) as channel: | ||
tasks_client = ArmoniKTasks(channel) | ||
curr_page = page if page > 0 else 0 | ||
tasks_list = [] | ||
while True: | ||
total, curr_tasks_list = tasks_client.list_tasks( | ||
task_filter=(Task.session_id == session_id) & filter_with | ||
if filter_with is not None | ||
else (Task.session_id == session_id), | ||
sort_field=Task.id if sort_by is None else sort_by, | ||
sort_direction=Direction.ASC | ||
if sort_order.capitalize() == "ASC" | ||
else Direction.DESC, | ||
page=curr_page, | ||
page_size=page_size, | ||
) | ||
tasks_list += curr_tasks_list | ||
|
||
if page >= 0 or total < page_size or curr_page > (total // page_size): | ||
break | ||
curr_page += 1 | ||
|
||
if total > 0: | ||
tasks_list = [_clean_up_status(task) for task in tasks_list] | ||
console.formatted_print(tasks_list, format=output, table_cols=TASKS_TABLE_COLS) | ||
|
||
|
||
@tasks.command() | ||
@click.argument("task-ids", type=str, nargs=-1, required=True) | ||
@base_command | ||
def get(endpoint: str, output: str, task_ids: List[str], debug: bool): | ||
"""Get a detailed overview of set of tasks given their ids.""" | ||
with grpc.insecure_channel(endpoint) as channel: | ||
tasks_client = ArmoniKTasks(channel) | ||
tasks = [] | ||
for task_id in task_ids: | ||
task = tasks_client.get_task(task_id) | ||
task = _clean_up_status(task) | ||
tasks.append(task) | ||
console.formatted_print(tasks, format=output, table_cols=TASKS_TABLE_COLS) | ||
|
||
|
||
@tasks.command() | ||
@click.argument("task-ids", type=str, nargs=-1, required=True) | ||
@base_command | ||
def cancel(endpoint: str, output: str, task_ids: List[str], debug: bool): | ||
"Cancel tasks given their ids. (They don't have to be in the same session necessarily)." | ||
with grpc.insecure_channel(endpoint) as channel: | ||
tasks_client = ArmoniKTasks(channel) | ||
tasks_client.cancel_tasks(task_ids) | ||
# We do try to catch errors but we're sort of just letting the API take the wheel on this one | ||
|
||
|
||
@tasks.command() | ||
@click.option( | ||
"--session-id", | ||
type=str, | ||
required=True, | ||
help="Id of the session to create the task in.", | ||
metavar="SESSION_ID", | ||
) | ||
@click.option( | ||
"--payload-id", | ||
type=str, | ||
required=True, | ||
help="Id of the payload to associated to the task.", | ||
metavar="PAYLOAD_ID", | ||
) | ||
@click.option( | ||
"--expected-outputs", | ||
multiple=True, | ||
required=True, | ||
help="List of the ids of the task's outputs.", | ||
metavar="EXPECTED_OUTPUTS", | ||
) | ||
@click.option( | ||
"--data-dependencies", | ||
multiple=True, | ||
help="List of the ids of the task's data dependencies.", | ||
metavar="DATA_DEPENDENCIES", | ||
) | ||
@click.option( | ||
"--max-retries", | ||
type=int, | ||
default=None, | ||
help="Maximum default number of execution attempts for this task.", | ||
metavar="NUM_RETRIES", | ||
) | ||
@click.option( | ||
"--max-duration", | ||
type=TimeDeltaParam(), | ||
default=None, | ||
help="Maximum default task execution time (format HH:MM:SS.MS).", | ||
metavar="DURATION", | ||
) | ||
@click.option("--priority", default=None, type=int, help="Task priority.", metavar="PRIORITY") | ||
@click.option( | ||
"--partition-id", | ||
type=str, | ||
help="Partition to run the task in.", | ||
metavar="PARTITION", | ||
) | ||
@click.option( | ||
"--application-name", | ||
type=str, | ||
required=False, | ||
help="Application name for this task.", | ||
metavar="NAME", | ||
) | ||
@click.option( | ||
"--application-version", | ||
type=str, | ||
required=False, | ||
help="Application version for this task.", | ||
metavar="VERSION", | ||
) | ||
@click.option( | ||
"--application-namespace", | ||
type=str, | ||
required=False, | ||
help="Application namespace for this task.", | ||
metavar="NAMESPACE", | ||
) | ||
@click.option( | ||
"--application-service", | ||
type=str, | ||
required=False, | ||
help="Application service for this task.", | ||
metavar="SERVICE", | ||
) | ||
@click.option("--engine-type", type=str, required=False, help="Engine type.", metavar="ENGINE_TYPE") | ||
@click.option( | ||
"--options", | ||
type=KeyValuePairParam(), | ||
default=None, | ||
multiple=True, | ||
help="Additional task options.", | ||
metavar="KEY=VALUE", | ||
) | ||
@base_command | ||
def create( | ||
endpoint: str, | ||
output: str, | ||
session_id: str, | ||
payload_id: str, | ||
expected_outputs: List[str], | ||
data_dependencies: Union[List[str], None], | ||
max_retries: Union[int, None], | ||
max_duration: Union[timedelta, None], | ||
priority: Union[int, None], | ||
partition_id: Union[str, None], | ||
application_name: Union[str, None], | ||
application_version: Union[str, None], | ||
application_namespace: Union[str, None], | ||
application_service: Union[str, None], | ||
engine_type: Union[str, None], | ||
options: Union[List[Tuple[str, str]], None], | ||
debug: bool, | ||
): | ||
"""Create a task.""" | ||
with grpc.insecure_channel(endpoint) as channel: | ||
tasks_client = ArmoniKTasks(channel) | ||
task_options = None | ||
if all([max_duration, priority, max_retries]): | ||
task_options = TaskOptions( | ||
max_duration, | ||
priority, | ||
max_retries, | ||
partition_id, | ||
application_name, | ||
application_version, | ||
application_namespace, | ||
application_service, | ||
engine_type, | ||
options, | ||
) | ||
elif any([max_duration, priority, max_retries]): | ||
click.echo( | ||
click.style( | ||
"If you want to pass in additional task options please provide all three (max duration, priority, max retries)", | ||
"red", | ||
) | ||
) | ||
return | ||
task_definition = TaskDefinition( | ||
payload_id, expected_outputs, data_dependencies, task_options | ||
) | ||
submitted_tasks = tasks_client.submit_tasks(session_id, [task_definition]) | ||
|
||
console.formatted_print( | ||
[_clean_up_status(t) for t in submitted_tasks], | ||
format=output, | ||
table_cols=TASKS_TABLE_COLS, | ||
) | ||
|
||
|
||
def _clean_up_status(task: Task) -> Task: | ||
task.status = TaskStatus(task.status).name.split("_")[-1].capitalize() | ||
task.output = task.output.error if task.output else None | ||
return task |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.