Skip to content

Commit

Permalink
Added task commands
Browse files Browse the repository at this point in the history
  • Loading branch information
AncientPatata committed Jan 9, 2025
1 parent 8665d48 commit bbdbb97
Show file tree
Hide file tree
Showing 8 changed files with 576 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/armonik_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ def cli() -> None:


cli.add_command(commands.sessions)
cli.add_command(commands.tasks)
4 changes: 2 additions & 2 deletions src/armonik_cli/commands/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .sessions import sessions
from .tasks import tasks


__all__ = ["sessions"]
__all__ = ["sessions", "tasks"]
258 changes: 258 additions & 0 deletions src/armonik_cli/commands/tasks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,258 @@
import grpc
import rich_click as click

from datetime import timedelta
from typing import List, Tuple, Union

from armonik.client.tasks import ArmoniKTasks
from armonik.common import Task, TaskStatus, TaskDefinition, TaskOptions, Direction
from armonik.common.filter import TaskFilter, Filter

from armonik_cli.core import console, base_command
from armonik_cli.core.params import KeyValuePairParam, TimeDeltaParam, FilterParam, FieldParam

TASKS_TABLE_COLS = [("ID", "Id"), ("Status", "Status"), ("CreatedAt", "CreatedAt")]


@click.group(name="task")
def tasks() -> None:
"""Manage cluster's tasks."""
pass


@tasks.command()
@click.argument("session-id", required=True, type=str)
@click.option(
"-f",
"--filter",
"filter_with",
type=FilterParam("Task"),
required=False,
help="An expression to filter the listed tasks with.",
metavar="FILTER EXPR",
)
@click.option(
"--sort_by", type=FieldParam("Task"), required=False, help="Attribute of task to sort with."
)
@click.option(
"--sort_order",
type=click.Choice(["asc", "desc"], case_sensitive=False),
default="asc",
required=False,
help="Whether to sort by ascending or by descending order.",
)
@click.option("--page", default=-1)
@click.option("--page-size", default=100)
@base_command
def list(
endpoint: str,
output: str,
session_id: str,
filter_with: Union[TaskFilter, None],
sort_by: Filter,
sort_order: str,
page: int,
page_size: int,
debug: bool,
) -> None:
"List all tasks in a session. This command takes the session id as an argument."
with grpc.insecure_channel(endpoint) as channel:
tasks_client = ArmoniKTasks(channel)
curr_page = page if page > 0 else 0
tasks_list = []
while True:
total, curr_tasks_list = tasks_client.list_tasks(
task_filter=(Task.session_id == session_id) & filter_with
if filter_with is not None
else (Task.session_id == session_id),
sort_field=Task.id if sort_by is None else sort_by,
sort_direction=Direction.ASC
if sort_order.capitalize() == "ASC"
else Direction.DESC,
page=curr_page,
page_size=page_size,
)
tasks_list += curr_tasks_list

if page >= 0 or total < page_size or curr_page > (total // page_size):
break
curr_page += 1

if total > 0:
tasks_list = [_clean_up_status(task) for task in tasks_list]
console.formatted_print(tasks_list, format=output, table_cols=TASKS_TABLE_COLS)


@tasks.command()
@click.argument("task-ids", type=str, nargs=-1, required=True)
@base_command
def get(endpoint: str, output: str, task_ids: List[str], debug: bool):
"""Get a detailed overview of set of tasks given their ids."""
with grpc.insecure_channel(endpoint) as channel:
tasks_client = ArmoniKTasks(channel)
tasks = []
for task_id in task_ids:
task = tasks_client.get_task(task_id)
task = _clean_up_status(task)
tasks.append(task)
console.formatted_print(tasks, format=output, table_cols=TASKS_TABLE_COLS)


@tasks.command()
@click.argument("task-ids", type=str, nargs=-1, required=True)
@base_command
def cancel(endpoint: str, output: str, task_ids: List[str], debug: bool):
"Cancel tasks given their ids. (They don't have to be in the same session necessarily)."
with grpc.insecure_channel(endpoint) as channel:
tasks_client = ArmoniKTasks(channel)
tasks_client.cancel_tasks(task_ids)
# We do try to catch errors but we're sort of just letting the API take the wheel on this one


@tasks.command()
@click.option(
"--session-id",
type=str,
required=True,
help="Id of the session to create the task in.",
metavar="SESSION_ID",
)
@click.option(
"--payload-id",
type=str,
required=True,
help="Id of the payload to associated to the task.",
metavar="PAYLOAD_ID",
)
@click.option(
"--expected-outputs",
multiple=True,
required=True,
help="List of the ids of the task's outputs.",
metavar="EXPECTED_OUTPUTS",
)
@click.option(
"--data-dependencies",
multiple=True,
help="List of the ids of the task's data dependencies.",
metavar="DATA_DEPENDENCIES",
)
@click.option(
"--max-retries",
type=int,
default=None,
help="Maximum default number of execution attempts for this task.",
metavar="NUM_RETRIES",
)
@click.option(
"--max-duration",
type=TimeDeltaParam(),
default=None,
help="Maximum default task execution time (format HH:MM:SS.MS).",
metavar="DURATION",
)
@click.option("--priority", default=None, type=int, help="Task priority.", metavar="PRIORITY")
@click.option(
"--partition-id",
type=str,
help="Partition to run the task in.",
metavar="PARTITION",
)
@click.option(
"--application-name",
type=str,
required=False,
help="Application name for this task.",
metavar="NAME",
)
@click.option(
"--application-version",
type=str,
required=False,
help="Application version for this task.",
metavar="VERSION",
)
@click.option(
"--application-namespace",
type=str,
required=False,
help="Application namespace for this task.",
metavar="NAMESPACE",
)
@click.option(
"--application-service",
type=str,
required=False,
help="Application service for this task.",
metavar="SERVICE",
)
@click.option("--engine-type", type=str, required=False, help="Engine type.", metavar="ENGINE_TYPE")
@click.option(
"--options",
type=KeyValuePairParam(),
default=None,
multiple=True,
help="Additional task options.",
metavar="KEY=VALUE",
)
@base_command
def create(
endpoint: str,
output: str,
session_id: str,
payload_id: str,
expected_outputs: List[str],
data_dependencies: Union[List[str], None],
max_retries: Union[int, None],
max_duration: Union[timedelta, None],
priority: Union[int, None],
partition_id: Union[str, None],
application_name: Union[str, None],
application_version: Union[str, None],
application_namespace: Union[str, None],
application_service: Union[str, None],
engine_type: Union[str, None],
options: Union[List[Tuple[str, str]], None],
debug: bool,
):
"""Create a task."""
with grpc.insecure_channel(endpoint) as channel:
tasks_client = ArmoniKTasks(channel)
task_options = None
if all([max_duration, priority, max_retries]):
task_options = TaskOptions(
max_duration,
priority,
max_retries,
partition_id,
application_name,
application_version,
application_namespace,
application_service,
engine_type,
options,
)
elif any([max_duration, priority, max_retries]):
click.echo(
click.style(
"If you want to pass in additional task options please provide all three (max duration, priority, max retries)",
"red",
)
)
return
task_definition = TaskDefinition(
payload_id, expected_outputs, data_dependencies, task_options
)
submitted_tasks = tasks_client.submit_tasks(session_id, [task_definition])

console.formatted_print(
[_clean_up_status(t) for t in submitted_tasks],
format=output,
table_cols=TASKS_TABLE_COLS,
)


def _clean_up_status(task: Task) -> Task:
task.status = TaskStatus(task.status).name.split("_")[-1].capitalize()
task.output = task.output.error if task.output else None
return task
4 changes: 4 additions & 0 deletions src/armonik_cli/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ def wrapper(*args, **kwargs):

if status_code == grpc.StatusCode.NOT_FOUND:
raise NotFoundError(error_details)
elif status_code == grpc.StatusCode.INTERNAL:
raise InternalError(f"An nternal exception has occured:\n{error_details}")
elif status_code == grpc.StatusCode.UNKNOWN:
raise InternalError(f"An unknown exception has occured:\n{error_details}")
else:
raise InternalError("An internal fatal error occured.")
except Exception:
Expand Down
53 changes: 53 additions & 0 deletions src/armonik_cli/core/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from datetime import timedelta
from typing import cast, Tuple, Union

from armonik import common
from armonik.common import Filter
from armonik.common.filter.filter import FType
from lark.exceptions import VisitError, UnexpectedInput

from armonik_cli.utils import parse_time_delta
Expand Down Expand Up @@ -135,3 +137,54 @@ def convert(
self.fail(f"Filter syntax error: {error.get_context(value, span=40)}.", param, ctx)
except VisitError as error:
self.fail(str(error.orig_exc), param, ctx)


class FieldParam(click.ParamType):
"""
A custom Click parameter type that validates a field name against the possible fields of a base structure.
Attributes:
name: The name of the parameter type, used by Click.
"""

name = "field"

def __init__(self, base_struct: str) -> None:
"""
Initializes the FieldParam instance and gets the fields from the provided base structure.
Args:
base_struct: The base structure name to validate fields against (e.g., "Task", "Session").
"""
super().__init__()
self.base_struct = base_struct.capitalize()
cls = getattr(common.filter, f"{self.base_struct}Filter")
self.possible_fields = [
field
for field in cls._fields.keys()
if cls._fields[field][0] != FType.NA and cls._fields[field][0] != FType.UNKNOWN
]

def convert(
self, value: str, param: Union[click.Parameter, None], ctx: Union[click.Context, None]
) -> Filter:
"""
Converts the provided value into a field after checking if said value is supported.
Args:
value: The input field name to validate.
param: The parameter object passed by Click.
ctx: The context in which the parameter is being used.
Returns:
A field object.
Raises:
click.BadParameter: If the input field is not valid.
"""
if value not in self.possible_fields:
self.fail(
f"{self.base_struct} has no attribute with the name {value}, only valid choices are {','.join(self.possible_fields)}"
)
cls = getattr(common, self.base_struct)
return getattr(cls, value)
4 changes: 2 additions & 2 deletions src/armonik_cli/core/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from datetime import datetime, timedelta
from typing import Dict, Union, Any

from armonik.common import Session, TaskOptions
from armonik.common import Session, TaskOptions, Task
from google._upb._message import ScalarMapContainer


Expand All @@ -16,7 +16,7 @@ class CLIJSONEncoder(json.JSONEncoder):
__api_types: The list of ArmoniK API Python objects managed by this encoder.
"""

__api_types = [Session, TaskOptions]
__api_types = [Session, TaskOptions, Task]

def default(self, obj: object) -> Union[str, Dict[str, Any]]:
"""
Expand Down
Loading

0 comments on commit bbdbb97

Please sign in to comment.