diff --git a/src/armonik_cli/cli.py b/src/armonik_cli/cli.py index e0c2ef9..374b10d 100644 --- a/src/armonik_cli/cli.py +++ b/src/armonik_cli/cli.py @@ -13,3 +13,4 @@ def cli() -> None: cli.add_command(commands.sessions) +cli.add_command(commands.partitions) diff --git a/src/armonik_cli/commands/__init__.py b/src/armonik_cli/commands/__init__.py index d3bff87..5a3a020 100644 --- a/src/armonik_cli/commands/__init__.py +++ b/src/armonik_cli/commands/__init__.py @@ -1,4 +1,4 @@ from .sessions import sessions +from .partitions import partitions - -__all__ = ["sessions"] +__all__ = ["sessions", "partitions"] diff --git a/src/armonik_cli/commands/partitions.py b/src/armonik_cli/commands/partitions.py new file mode 100644 index 0000000..1120d05 --- /dev/null +++ b/src/armonik_cli/commands/partitions.py @@ -0,0 +1,84 @@ +import grpc +import rich_click as click + +from typing import Union + +from armonik.client.partitions import ArmoniKPartitions +from armonik.common.filter import PartitionFilter + +from armonik_cli.core import console, base_command +from armonik_cli.core.params import FilterParam + +PARTITIONS_TABLE_COLS = [("ID", "Id"), ("PodReserved", "PodReserved"), ("PodMax", "PodMax")] + + +@click.group(name="partition") +def partitions() -> None: + """Manage cluster partitions.""" + pass + + +@partitions.command() +@click.option( + "-f", + "--filter", + "filter_with", + type=FilterParam("Partition"), + required=False, + help="An expression to filter partitions with", + metavar="", +) +# @click.option( +# "--sort_by", type=FieldParam("Task"), required=False, help="Attribute of partition to sort with." +# ) +# @click.option( +# "--sort_order", +# type=click.Choice(["asc", "desc"], case_sensitive=False), +# default="asc", +# required=False, +# help="Whether to sort by ascending or by descending order.", +# ) +@click.option( + "--page", default=-1, help="Get a specific page, it defaults to -1 which gets all pages." +) +@click.option("--page-size", default=100, help="Number of elements in each page") +@base_command +def list( + endpoint: str, + output: str, + filter_with: Union[PartitionFilter, None], + page: int, + page_size: int, + debug: bool, +) -> None: + """List the partitions in an ArmoniK cluster.""" + with grpc.insecure_channel(endpoint) as channel: + partitions_client = ArmoniKPartitions(channel) + curr_page = page if page > 0 else 0 + partitions_list = [] + while True: + total, partitions = partitions_client.list_partitions( + partition_filter=filter_with, + page=curr_page, + page_size=page_size, + ) + partitions_list += partitions + if page > 0 or len(partitions_list) >= total: + break + curr_page += 1 + + if total > 0: + console.formatted_print( + partitions_list, format=output, table_cols=PARTITIONS_TABLE_COLS + ) + + +@partitions.command() +@click.argument("partition-id", type=str, required=True) +@base_command +def get(endpoint: str, output: str, partition_id: str, debug: bool) -> None: + """Get a specific partition from an ArmoniK cluster given a .""" + with grpc.insecure_channel(endpoint) as channel: + partitions_client = ArmoniKPartitions(channel) + partition = partitions_client.get_partition(partition_id) + console.formatted_print(partition, format=output, table_cols=PARTITIONS_TABLE_COLS) diff --git a/src/armonik_cli/core/console.py b/src/armonik_cli/core/console.py index 66f8043..ee45068 100644 --- a/src/armonik_cli/core/console.py +++ b/src/armonik_cli/core/console.py @@ -65,7 +65,7 @@ def _build_table(obj: Dict[str, Any], table_cols: List[Tuple[str, str]]) -> Tabl objs = obj if isinstance(obj, List) else [obj] for item in objs: - table.add_row(*[item[key] for _, key in table_cols]) + table.add_row(*[str(item[key]) for _, key in table_cols]) return table diff --git a/src/armonik_cli/core/serialize.py b/src/armonik_cli/core/serialize.py index 861bc4c..cb7bf79 100644 --- a/src/armonik_cli/core/serialize.py +++ b/src/armonik_cli/core/serialize.py @@ -1,10 +1,10 @@ import json from datetime import datetime, timedelta -from typing import Dict, Union, Any +from typing import Dict, List, Union, Any -from armonik.common import Session, TaskOptions -from google._upb._message import ScalarMapContainer +from armonik.common import Session, TaskOptions, Partition +from google._upb._message import ScalarMapContainer, RepeatedScalarContainer class CLIJSONEncoder(json.JSONEncoder): @@ -16,9 +16,9 @@ class CLIJSONEncoder(json.JSONEncoder): __api_types: The list of ArmoniK API Python objects managed by this encoder. """ - __api_types = [Session, TaskOptions] + __api_types = [Session, TaskOptions, Partition] - def default(self, obj: object) -> Union[str, Dict[str, Any]]: + def default(self, obj: object) -> Union[str, Dict[str, Any], List[Any]]: """ Override the `default` method to serialize non-serializable objects to JSON. @@ -36,6 +36,8 @@ def default(self, obj: object) -> Union[str, Dict[str, Any]]: # serializing the associated gRPC object. elif isinstance(obj, ScalarMapContainer): return json.loads(str(obj).replace("'", '"')) + elif isinstance(obj, RepeatedScalarContainer): + return list(obj) elif any([isinstance(obj, api_type) for api_type in self.__api_types]): return {self.camel_case(k): v for k, v in obj.__dict__.items()} else: diff --git a/tests/commands/test_partitions.py b/tests/commands/test_partitions.py new file mode 100644 index 0000000..ac0c58e --- /dev/null +++ b/tests/commands/test_partitions.py @@ -0,0 +1,50 @@ +from copy import deepcopy +import pytest + +from armonik.client import ArmoniKPartitions +from armonik.common import Partition + +from conftest import run_cmd_and_assert_exit_code, reformat_cmd_output + +ENDPOINT = "172.17.119.85:5001" + +raw_partition = Partition( + id="stream", + parent_partition_ids=[], + pod_reserved=1, + pod_max=100, + pod_configuration={}, + preemption_percentage=50, + priority=1, +) + +serialized_partition = { + "Id": "stream", + "ParentPartitionIds": [], + "PodReserved": 1, + "PodMax": 100, + "PodConfiguration": {}, + "PreemptionPercentage": 50, + "Priority": 1, +} + + +@pytest.mark.parametrize("cmd", [f"partition list -e {ENDPOINT} --debug"]) +def test_partition_list(mocker, cmd): + mocker.patch.object( + ArmoniKPartitions, "list_partitions", return_value=(1, [deepcopy(raw_partition)]) + ) + result = run_cmd_and_assert_exit_code(cmd) + assert reformat_cmd_output(result.output, deserialize=True) == [serialized_partition] + + +@pytest.mark.parametrize( + "cmd", + [ + f"partition get --endpoint {ENDPOINT} {serialized_partition['Id']}", + ], +) +def test_partition_get(mocker, cmd): + mocker.patch.object(ArmoniKPartitions, "get_partition", return_value=raw_partition) + result = run_cmd_and_assert_exit_code(cmd) + assert reformat_cmd_output(result.output, deserialize=True) == serialized_partition