diff --git a/src/armonik_cli/cli.py b/src/armonik_cli/cli.py index e0c2ef9..374b10d 100644 --- a/src/armonik_cli/cli.py +++ b/src/armonik_cli/cli.py @@ -13,3 +13,4 @@ def cli() -> None: cli.add_command(commands.sessions) +cli.add_command(commands.partitions) diff --git a/src/armonik_cli/commands/__init__.py b/src/armonik_cli/commands/__init__.py index d3bff87..5a3a020 100644 --- a/src/armonik_cli/commands/__init__.py +++ b/src/armonik_cli/commands/__init__.py @@ -1,4 +1,4 @@ from .sessions import sessions +from .partitions import partitions - -__all__ = ["sessions"] +__all__ = ["sessions", "partitions"] diff --git a/src/armonik_cli/commands/partitions.py b/src/armonik_cli/commands/partitions.py new file mode 100644 index 0000000..0f405c0 --- /dev/null +++ b/src/armonik_cli/commands/partitions.py @@ -0,0 +1,94 @@ +import grpc +import rich_click as click + +from typing import List, Union + +from armonik.client.partitions import ArmoniKPartitions +from armonik.common.filter import Filter, PartitionFilter +from armonik.common import Partition, Direction + +from armonik_cli.core import console, base_command +from armonik_cli.core.params import FilterParam,FieldParam + +PARTITIONS_TABLE_COLS = [("ID", "Id"), ("PodReserved", "PodReserved"), ("PodMax", "PodMax")] + + +@click.group(name="partition") +def partitions() -> None: + """Manage cluster partitions.""" + pass + + +@partitions.command() +@click.option( + "-f", + "--filter", + "filter_with", + type=FilterParam("Partition"), + required=False, + help="An expression to filter partitions with", + metavar="FILTER EXPR", +) +@click.option( + "--sort-by", type=FieldParam("Task"), required=False, help="Attribute of partition to sort with." +) +@click.option( + "--sort-order", + type=click.Choice(["asc", "desc"], case_sensitive=False), + default="asc", + required=False, + help="Whether to sort by ascending or by descending order.", +) +@click.option( + "--page", default=-1, help="Get a specific page, it defaults to -1 which gets all pages." +) +@click.option("--page-size", default=100, help="Number of elements in each page") +@base_command +def list( + endpoint: str, + output: str, + filter_with: Union[PartitionFilter, None], + sort_by: Filter, + sort_direction: str, + page: int, + page_size: int, + debug: bool, +) -> None: + """List the partitions in an ArmoniK cluster.""" + with grpc.insecure_channel(endpoint) as channel: + partitions_client = ArmoniKPartitions(channel) + curr_page = page if page > 0 else 0 + partitions_list = [] + while True: + total, partitions = partitions_client.list_partitions( + partition_filter=filter_with, + sort_field=Partition.id if sort_by is None else sort_by, + sort_direction=Direction.ASC + if sort_direction.capitalize() == "ASC" + else Direction.DESC, + page=curr_page, + page_size=page_size, + ) + partitions_list += partitions + if page > 0 or len(partitions_list) >= total: + break + curr_page += 1 + + if total > 0: + console.formatted_print( + partitions_list, format=output, table_cols=PARTITIONS_TABLE_COLS + ) + + +@partitions.command() +@click.argument("partition-ids", type=str,nargs=-1, required=True) +@base_command +def get(endpoint: str, output: str, partition_ids: List[str], debug: bool) -> None: + """Get a specific partition from an ArmoniK cluster given a .""" + with grpc.insecure_channel(endpoint) as channel: + partitions_client = ArmoniKPartitions(channel) + partitions = [] + for partition_id in partition_ids: + partition = partitions_client.get_partition(partition_id) + partitions.append(partition) + console.formatted_print(partitions, format=output, table_cols=PARTITIONS_TABLE_COLS) diff --git a/src/armonik_cli/core/console.py b/src/armonik_cli/core/console.py index 66f8043..ee45068 100644 --- a/src/armonik_cli/core/console.py +++ b/src/armonik_cli/core/console.py @@ -65,7 +65,7 @@ def _build_table(obj: Dict[str, Any], table_cols: List[Tuple[str, str]]) -> Tabl objs = obj if isinstance(obj, List) else [obj] for item in objs: - table.add_row(*[item[key] for _, key in table_cols]) + table.add_row(*[str(item[key]) for _, key in table_cols]) return table diff --git a/src/armonik_cli/core/serialize.py b/src/armonik_cli/core/serialize.py index 861bc4c..cb7bf79 100644 --- a/src/armonik_cli/core/serialize.py +++ b/src/armonik_cli/core/serialize.py @@ -1,10 +1,10 @@ import json from datetime import datetime, timedelta -from typing import Dict, Union, Any +from typing import Dict, List, Union, Any -from armonik.common import Session, TaskOptions -from google._upb._message import ScalarMapContainer +from armonik.common import Session, TaskOptions, Partition +from google._upb._message import ScalarMapContainer, RepeatedScalarContainer class CLIJSONEncoder(json.JSONEncoder): @@ -16,9 +16,9 @@ class CLIJSONEncoder(json.JSONEncoder): __api_types: The list of ArmoniK API Python objects managed by this encoder. """ - __api_types = [Session, TaskOptions] + __api_types = [Session, TaskOptions, Partition] - def default(self, obj: object) -> Union[str, Dict[str, Any]]: + def default(self, obj: object) -> Union[str, Dict[str, Any], List[Any]]: """ Override the `default` method to serialize non-serializable objects to JSON. @@ -36,6 +36,8 @@ def default(self, obj: object) -> Union[str, Dict[str, Any]]: # serializing the associated gRPC object. elif isinstance(obj, ScalarMapContainer): return json.loads(str(obj).replace("'", '"')) + elif isinstance(obj, RepeatedScalarContainer): + return list(obj) elif any([isinstance(obj, api_type) for api_type in self.__api_types]): return {self.camel_case(k): v for k, v in obj.__dict__.items()} else: diff --git a/tests/commands/test_partitions.py b/tests/commands/test_partitions.py new file mode 100644 index 0000000..c181f5f --- /dev/null +++ b/tests/commands/test_partitions.py @@ -0,0 +1,81 @@ +from copy import deepcopy +import pytest + +from armonik.client import ArmoniKPartitions +from armonik.common import Partition + +from conftest import run_cmd_and_assert_exit_code, reformat_cmd_output + +ENDPOINT = "172.17.119.85:5001" + + +raw_partitions = [ + Partition( + id="stream", + parent_partition_ids=[], + pod_reserved=1, + pod_max=100, + pod_configuration={}, + preemption_percentage=50, + priority=1, + ), + Partition( + id="bench", + parent_partition_ids=[], + pod_reserved=1, + pod_max=100, + pod_configuration={}, + preemption_percentage=50, + priority=1, + ) + +] + +serialized_partitions = [ + { + "Id": "stream", + "ParentPartitionIds": [], + "PodReserved": 1, + "PodMax": 100, + "PodConfiguration": {}, + "PreemptionPercentage": 50, + "Priority": 1, +}, +{ + "Id": "bench", + "ParentPartitionIds": [], + "PodReserved": 1, + "PodMax": 100, + "PodConfiguration": {}, + "PreemptionPercentage": 50, + "Priority": 1, +} +] + + +@pytest.mark.parametrize("cmd", [f"partition list -e {ENDPOINT} --debug"]) +def test_partition_list(mocker, cmd): + mocker.patch.object( + ArmoniKPartitions, "list_partitions", return_value=(len(raw_partitions), deepcopy(raw_partitions)) + ) + result = run_cmd_and_assert_exit_code(cmd) + assert reformat_cmd_output(result.output, deserialize=True) == serialized_partitions + + +@pytest.mark.parametrize( + "cmd, expected_output", + [ + (f"partition get --endpoint {ENDPOINT} {serialized_partitions[0]['Id']}", serialized_partitions[0]), + (f"partition get --endpoint {ENDPOINT} {serialized_partitions[0]['Id']} {serialized_partitions[1]['Id']}", [serialized_partitions[0], serialized_partitions[1]]) + + ], +) +def test_partition_get(mocker, cmd, expected_output): + def get_partitions_side_effect(partition_id): + if partition_id == serialized_partitions[0]['Id']: + return deepcopy(raw_partitions[0]) + elif partition_id == serialized_partitions[0]['Id']: + return deepcopy(raw_partitions[1]) + mocker.patch.object(ArmoniKPartitions, "get_partition", side_effect=get_partitions_side_effect) + result = run_cmd_and_assert_exit_code(cmd) + assert reformat_cmd_output(result.output, deserialize=True) == expected_output