Skip to content

Commit

Permalink
Added partition commands
Browse files Browse the repository at this point in the history
  • Loading branch information
AncientPatata committed Jan 15, 2025
1 parent a32b4dc commit 733747a
Show file tree
Hide file tree
Showing 6 changed files with 197 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/armonik_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@ def cli() -> None:

cli.add_command(commands.sessions)
cli.add_command(commands.tasks)
cli.add_command(commands.partitions)
4 changes: 3 additions & 1 deletion src/armonik_cli/commands/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .sessions import sessions
from .tasks import tasks
from .partitions import partitions

__all__ = ["sessions", "tasks"]

__all__ = ["sessions", "tasks", "partitions"]
97 changes: 97 additions & 0 deletions src/armonik_cli/commands/partitions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import grpc
import rich_click as click

from typing import List, Union

from armonik.client.partitions import ArmoniKPartitions
from armonik.common.filter import Filter, PartitionFilter
from armonik.common import Partition, Direction

from armonik_cli.core import console, base_command
from armonik_cli.core.params import FilterParam, FieldParam

PARTITIONS_TABLE_COLS = [("ID", "Id"), ("PodReserved", "PodReserved"), ("PodMax", "PodMax")]


@click.group(name="partition")
def partitions() -> None:
"""Manage cluster partitions."""
pass


@partitions.command()
@click.option(
"-f",
"--filter",
"filter_with",
type=FilterParam("Partition"),
required=False,
help="An expression to filter partitions with",
metavar="FILTER EXPR",
)
@click.option(
"--sort-by",
type=FieldParam("Task"),
required=False,
help="Attribute of partition to sort with.",
)
@click.option(
"--sort-direction",
type=click.Choice(["asc", "desc"], case_sensitive=False),
default="asc",
required=False,
help="Whether to sort by ascending or by descending order.",
)
@click.option(
"--page", default=-1, help="Get a specific page, it defaults to -1 which gets all pages."
)
@click.option("--page-size", default=100, help="Number of elements in each page")
@base_command
def list(
endpoint: str,
output: str,
filter_with: Union[PartitionFilter, None],
sort_by: Filter,
sort_direction: str,
page: int,
page_size: int,
debug: bool,
) -> None:
"""List the partitions in an ArmoniK cluster."""
with grpc.insecure_channel(endpoint) as channel:
partitions_client = ArmoniKPartitions(channel)
curr_page = page if page > 0 else 0
partitions_list = []
while True:
total, partitions = partitions_client.list_partitions(
partition_filter=filter_with,
sort_field=Partition.id if sort_by is None else sort_by,
sort_direction=Direction.ASC
if sort_direction.capitalize() == "ASC"
else Direction.DESC,
page=curr_page,
page_size=page_size,
)
partitions_list += partitions
if page > 0 or len(partitions_list) >= total:
break
curr_page += 1

if total > 0:
console.formatted_print(
partitions_list, format=output, table_cols=PARTITIONS_TABLE_COLS
)


@partitions.command()
@click.argument("partition-ids", type=str, nargs=-1, required=True)
@base_command
def get(endpoint: str, output: str, partition_ids: List[str], debug: bool) -> None:
"""Get a specific partition from an ArmoniK cluster given a <PARTITION-ID>."""
with grpc.insecure_channel(endpoint) as channel:
partitions_client = ArmoniKPartitions(channel)
partitions = []
for partition_id in partition_ids:
partition = partitions_client.get_partition(partition_id)
partitions.append(partition)
console.formatted_print(partitions, format=output, table_cols=PARTITIONS_TABLE_COLS)
2 changes: 1 addition & 1 deletion src/armonik_cli/core/console.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def _build_table(obj: Dict[str, Any], table_cols: List[Tuple[str, str]]) -> Tabl

objs = obj if isinstance(obj, List) else [obj]
for item in objs:
table.add_row(*[item[key] for _, key in table_cols])
table.add_row(*[str(item[key]) for _, key in table_cols])

return table

Expand Down
12 changes: 7 additions & 5 deletions src/armonik_cli/core/serialize.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import json

from datetime import datetime, timedelta
from typing import Dict, Union, Any
from typing import Dict, List, Union, Any

from armonik.common import Session, TaskOptions, Task
from google._upb._message import ScalarMapContainer
from armonik.common import Session, TaskOptions, Task, Partition
from google._upb._message import ScalarMapContainer, RepeatedScalarContainer


class CLIJSONEncoder(json.JSONEncoder):
Expand All @@ -16,9 +16,9 @@ class CLIJSONEncoder(json.JSONEncoder):
__api_types: The list of ArmoniK API Python objects managed by this encoder.
"""

__api_types = [Session, TaskOptions, Task]
__api_types = [Session, TaskOptions, Task, Partition]

def default(self, obj: object) -> Union[str, Dict[str, Any]]:
def default(self, obj: object) -> Union[str, Dict[str, Any], List[Any]]:
"""
Override the `default` method to serialize non-serializable objects to JSON.
Expand All @@ -36,6 +36,8 @@ def default(self, obj: object) -> Union[str, Dict[str, Any]]:
# serializing the associated gRPC object.
elif isinstance(obj, ScalarMapContainer):
return json.loads(str(obj).replace("'", '"'))
elif isinstance(obj, RepeatedScalarContainer):
return list(obj)
elif any([isinstance(obj, api_type) for api_type in self.__api_types]):
return {self.camel_case(k): v for k, v in obj.__dict__.items()}
else:
Expand Down
88 changes: 88 additions & 0 deletions tests/commands/test_partitions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from copy import deepcopy
import pytest

from armonik.client import ArmoniKPartitions
from armonik.common import Partition

from conftest import run_cmd_and_assert_exit_code, reformat_cmd_output

ENDPOINT = "172.17.119.85:5001"


raw_partitions = [
Partition(
id="stream",
parent_partition_ids=[],
pod_reserved=1,
pod_max=100,
pod_configuration={},
preemption_percentage=50,
priority=1,
),
Partition(
id="bench",
parent_partition_ids=[],
pod_reserved=1,
pod_max=100,
pod_configuration={},
preemption_percentage=50,
priority=1,
),
]

serialized_partitions = [
{
"Id": "stream",
"ParentPartitionIds": [],
"PodReserved": 1,
"PodMax": 100,
"PodConfiguration": {},
"PreemptionPercentage": 50,
"Priority": 1,
},
{
"Id": "bench",
"ParentPartitionIds": [],
"PodReserved": 1,
"PodMax": 100,
"PodConfiguration": {},
"PreemptionPercentage": 50,
"Priority": 1,
},
]


@pytest.mark.parametrize("cmd", [f"partition list -e {ENDPOINT} --debug"])
def test_partition_list(mocker, cmd):
mocker.patch.object(
ArmoniKPartitions,
"list_partitions",
return_value=(len(raw_partitions), deepcopy(raw_partitions)),
)
result = run_cmd_and_assert_exit_code(cmd)
assert reformat_cmd_output(result.output, deserialize=True) == serialized_partitions


@pytest.mark.parametrize(
"cmd, expected_output",
[
(
f"partition get --endpoint {ENDPOINT} {serialized_partitions[0]['Id']}",
[serialized_partitions[0]],
),
(
f"partition get --endpoint {ENDPOINT} {serialized_partitions[0]['Id']} {serialized_partitions[1]['Id']}",
[serialized_partitions[0], serialized_partitions[1]],
),
],
)
def test_partition_get(mocker, cmd, expected_output):
def get_partitions_side_effect(partition_id):
if partition_id == serialized_partitions[0]["Id"]:
return deepcopy(raw_partitions[0])
elif partition_id == serialized_partitions[1]["Id"]:
return deepcopy(raw_partitions[1])

mocker.patch.object(ArmoniKPartitions, "get_partition", side_effect=get_partitions_side_effect)
result = run_cmd_and_assert_exit_code(cmd)
assert reformat_cmd_output(result.output, deserialize=True) == expected_output

0 comments on commit 733747a

Please sign in to comment.