From 5dd512a164532d23d9ad4fb2323cfc4868739c61 Mon Sep 17 00:00:00 2001
From: Ean Garvey <87458719+monorimet@users.noreply.github.com>
Date: Wed, 13 Nov 2024 21:23:28 -0600
Subject: [PATCH] (shortfin-sd) readme / client simplifications (#504)

---
 .../python/shortfin/support/logging_setup.py |  2 +-
 shortfin/python/shortfin_apps/sd/README.md   | 25 ++------
 .../sd/components/config_artifacts.py        | 12 ----
 shortfin/python/shortfin_apps/sd/server.py   |  2 +-
 .../python/shortfin_apps/sd/simple_client.py | 61 +++++++++----------
 5 files changed, 37 insertions(+), 65 deletions(-)

diff --git a/shortfin/python/shortfin/support/logging_setup.py b/shortfin/python/shortfin/support/logging_setup.py
index 3cb373f1e..849d65bf3 100644
--- a/shortfin/python/shortfin/support/logging_setup.py
+++ b/shortfin/python/shortfin/support/logging_setup.py
@@ -38,7 +38,7 @@ def __init__(self):
 native_handler.setFormatter(NativeFormatter())
 
 # TODO: Source from env vars.
-logger.setLevel(logging.DEBUG)
+logger.setLevel(logging.WARNING)
 logger.addHandler(native_handler)
 
 
diff --git a/shortfin/python/shortfin_apps/sd/README.md b/shortfin/python/shortfin_apps/sd/README.md
index 6dd701c62..30002ec40 100644
--- a/shortfin/python/shortfin_apps/sd/README.md
+++ b/shortfin/python/shortfin_apps/sd/README.md
@@ -10,32 +10,19 @@ In your shortfin environment,
 pip install transformers
 pip install dataclasses-json
 pip install pillow
+pip install shark-ai
 ```
 
 ```
 python -m shortfin_apps.sd.server --help
 ```
 
-## Run tests
-
-  - From SHARK-Platform/shortfin:
- ```
- pytest --system=amdgpu -k "sd"
- ```
- The tests run with splat weights.
-
-
-## Run on MI300x
-
-  - Follow quick start
+# Run on MI300x
+The server will prepare runtime artifacts for you.
 
-  - Navigate to shortfin/ (only necessary if you're using following CLI exactly.)
-```
-cd shortfin/
-```
-  - Run CLI server interface (you can find `sdxl_config_i8.json` in shortfin_apps/sd/examples):
+By default, the port is set to 8000. If you would like to change this, use `--port` in each of the following commands.
 
-The server will prepare runtime artifacts for you.
+You can check if this (or any) port is in use on Linux with `ss -ntl | grep 8000`.
 
 ```
 python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_preference=precompiled --topology="spx_single"
@@ -43,5 +30,5 @@ python -m shortfin_apps.sd.server --device=amdgpu --device_ids=0 --build_prefere
 
 - Run a CLI client in a separate shell:
 ```
-python -m shortfin_apps.sd.simple_client --interactive --save
+python -m shortfin_apps.sd.simple_client --interactive
 ```
diff --git a/shortfin/python/shortfin_apps/sd/components/config_artifacts.py b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py
index b5a1d682b..f3502f22e 100644
--- a/shortfin/python/shortfin_apps/sd/components/config_artifacts.py
+++ b/shortfin/python/shortfin_apps/sd/components/config_artifacts.py
@@ -11,18 +11,6 @@
 import shortfin.array as sfnp
 import copy
 
-from shortfin_apps.sd.components.config_struct import ModelParams
-
-this_dir = os.path.dirname(os.path.abspath(__file__))
-parent = os.path.dirname(this_dir)
-
-dtype_to_filetag = {
-    sfnp.float16: "fp16",
-    sfnp.float32: "fp32",
-    sfnp.int8: "i8",
-    sfnp.bfloat16: "bf16",
-}
-
 ARTIFACT_VERSION = "11132024"
 SDXL_CONFIG_BUCKET = f"https://sharkpublic.blob.core.windows.net/sharkpublic/sdxl/{ARTIFACT_VERSION}/configs/"
 
diff --git a/shortfin/python/shortfin_apps/sd/server.py b/shortfin/python/shortfin_apps/sd/server.py
index 9ee81d1c4..2b7a93a91 100644
--- a/shortfin/python/shortfin_apps/sd/server.py
+++ b/shortfin/python/shortfin_apps/sd/server.py
@@ -355,7 +355,7 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
     parser.add_argument(
         "--tuning_spec",
         type=str,
-        default="",
+        default=None,
         help="Path to transform dialect spec if compiling an executable with tunings.",
     )
     parser.add_argument(
diff --git a/shortfin/python/shortfin_apps/sd/simple_client.py b/shortfin/python/shortfin_apps/sd/simple_client.py
index f8aabd8e7..550fd7c60 100644
--- a/shortfin/python/shortfin_apps/sd/simple_client.py
+++ b/shortfin/python/shortfin_apps/sd/simple_client.py
@@ -32,7 +32,7 @@
 }
 
 
-def bytes_to_img(bytes, idx=0, width=1024, height=1024, outputdir="./gen_imgs"):
+def bytes_to_img(bytes, outputdir, idx=0, width=1024, height=1024):
     timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
     image = Image.frombytes(
         mode="RGB", size=(width, height), data=base64.b64decode(bytes)
     )
@@ -46,6 +46,7 @@ def get_batched(request, arg, idx):
     if isinstance(request[arg], list):
+        # some args are broadcasted to each prompt, hence overriding idx for single-item entries
         if len(request[arg]) == 1:
             indexed = request[arg][0]
         else:
             indexed = request[arg][idx]
@@ -56,34 +57,30 @@
 
 
 async def send_request(session, rep, args, data):
-    try:
-        print("Sending request batch #", rep)
-        url = f"http://0.0.0.0:{args.port}/generate"
-        start = time.time()
-        async with session.post(url, json=data) as response:
-            end = time.time()
-            # Check if the response was successful
-            if response.status == 200:
-                response.raise_for_status()  # Raise an error for bad responses
-                timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
-                res_json = await response.json(content_type=None)
-                if args.save:
-                    for idx, item in enumerate(res_json["images"]):
-                        width = get_batched(data, "width", idx)
-                        height = get_batched(data, "height", idx)
-                        print("Saving response as image...")
-                        bytes_to_img(
-                            item.encode("utf-8"), idx, width, height, args.outputdir
-                        )
-                latency = end - start
-                print("Responses processed.")
-                return latency, len(data["prompt"])
-            else:
-                print(f"Error: Received {response.status} from server")
-                raise Exception
-    except Exception as e:
-        print(f"Request failed: {e}")
-        raise Exception
+    print("Sending request batch #", rep)
+    url = f"http://0.0.0.0:{args.port}/generate"
+    start = time.time()
+    async with session.post(url, json=data) as response:
+        end = time.time()
+        # Check if the response was successful
+        if response.status == 200:
+            response.raise_for_status()  # Raise an error for bad responses
+            timestamp = dt.now().strftime("%Y-%m-%d_%H-%M-%S")
+            res_json = await response.json(content_type=None)
+            if args.save:
+                for idx, item in enumerate(res_json["images"]):
+                    width = get_batched(data, "width", idx)
+                    height = get_batched(data, "height", idx)
+                    print("Saving response as image...")
+                    bytes_to_img(
+                        item.encode("utf-8"), args.outputdir, idx, width, height
+                    )
+            latency = end - start
+            print("Responses processed.")
+            return latency, len(data["prompt"])
+        else:
+            print(f"Error: Received {response.status} from server")
+            raise Exception
 
 
 async def static(args):
@@ -94,7 +91,7 @@ async def static(args):
     sample_counts = []
     # Read the JSON file if supplied. Otherwise, get user input.
     try:
-        if args.file == "default":
+        if not args.file:
             data = sample_request
         else:
             with open(args.file, "r") as json_file:
@@ -135,7 +132,7 @@ async def interactive(args):
     sample_counts = []
     # Read the JSON file if supplied. Otherwise, get user input.
    try:
-        if args.file == "default":
+        if not args.file:
             data = sample_request
         else:
             with open(args.file, "r") as json_file:
@@ -185,7 +182,7 @@ def main(argv):
     p.add_argument(
         "--file",
         type=str,
-        default="default",
+        default=None,
         help="A non-default request to send to the server.",
    )
     p.add_argument(
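
For reference, a minimal standalone request against a server started as in the README above looks roughly like the sketch below. It is grounded only in what this patch shows: the `/generate` endpoint on the default port 8000, the `prompt`/`height`/`width` fields handled by `get_batched`, and the base64-encoded RGB payload decoded in `bytes_to_img`. The payload is deliberately partial and assumes the server fills in defaults for omitted fields; if it does not, start from the full `sample_request` in `simple_client.py`. It also uses `requests` for brevity (the bundled client uses `aiohttp`), and the output filename is illustrative rather than the client's naming scheme.

```python
# Hypothetical one-off request to a running shortfin SD server (a sketch, not
# the project's client). Mirrors the base64 RGB decoding done by bytes_to_img().
import base64

import requests
from PIL import Image

PORT = 8000  # README default; change if the server was started with --port
WIDTH, HEIGHT = 1024, 1024

# Partial payload: only fields visible in this patch. The server may expect the
# remaining sample_request fields (negative prompt, steps, seed, ...).
payload = {
    "prompt": ["a photo of a cat under the snow, cinematic style"],
    "height": [HEIGHT],
    "width": [WIDTH],
}

resp = requests.post(f"http://0.0.0.0:{PORT}/generate", json=payload)
resp.raise_for_status()

# The response carries one base64-encoded raw RGB buffer per generated image.
for idx, b64_data in enumerate(resp.json()["images"]):
    image = Image.frombytes(
        mode="RGB", size=(WIDTH, HEIGHT), data=base64.b64decode(b64_data)
    )
    image.save(f"generated_{idx}.png")
```

For batched runs and latency numbers, the bundled client remains the intended path: `python -m shortfin_apps.sd.simple_client --interactive` for prompt-by-prompt use, with `--save` to write decoded images to the client's output directory, or `--file` to submit a request payload from JSON.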