From c98029b891169a30734581175b39fa7f8b7e33e5 Mon Sep 17 00:00:00 2001 From: Cedar Date: Tue, 28 Jan 2025 15:48:57 -0800 Subject: [PATCH 1/9] refacotring a bunch of things --- .../llm/server_management.py | 23 +--- .../llm/components/config_struct.py | 2 - .../llm/components/server_config.py | 104 ++++++++++++++++++ .../shortfin_apps/llm/components/service.py | 9 +- shortfin/python/shortfin_apps/llm/server.py | 26 ++++- 5 files changed, 134 insertions(+), 30 deletions(-) create mode 100644 shortfin/python/shortfin_apps/llm/components/server_config.py diff --git a/app_tests/integration_tests/llm/server_management.py b/app_tests/integration_tests/llm/server_management.py index 36c7b31b2..cd61f0cf6 100644 --- a/app_tests/integration_tests/llm/server_management.py +++ b/app_tests/integration_tests/llm/server_management.py @@ -61,31 +61,11 @@ def find_available_port() -> int: s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] - def _write_config(self) -> Path: - """Creates server config by extending the exported model config.""" - # TODO: eliminate this by moving prefix sharing algorithm to be a cmdline arg of server.py - source_config_path = self.config.artifacts.config_path - server_config_path = ( - source_config_path.parent - / f"server_config_{self.config.prefix_sharing_algorithm}.json" - ) - - # Read the exported config as base - with open(source_config_path) as f: - config = json.load(f) - config["paged_kv_cache"][ - "prefix_sharing_algorithm" - ] = self.config.prefix_sharing_algorithm - with open(server_config_path, "w") as f: - json.dump(config, f) - return server_config_path - def start(self) -> None: """Starts the server process.""" if self.process is not None: raise RuntimeError("Server is already running") - self.config_path = self._write_config() self.port = self.find_available_port() cmd = [ @@ -93,10 +73,11 @@ def start(self) -> None: "-m", "shortfin_apps.llm.server", f"--tokenizer_json={self.config.artifacts.tokenizer_path}", - f"--model_config={self.config_path}", + f"--model_config={self.config.artifacts.config_path}", f"--vmfb={self.config.artifacts.vmfb_path}", f"--parameters={self.config.artifacts.weights_path}", f"--port={self.port}", + f"--prefix_sharing_algorithm={self.config.prefix_sharing_algorithm}", ] cmd.extend(self.config.device_settings.server_flags) diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py index 8fefa0a12..8ff826d38 100644 --- a/shortfin/python/shortfin_apps/llm/components/config_struct.py +++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py @@ -91,8 +91,6 @@ class PagedKVCacheParams: # Default: 256 device_block_count: int - prefix_sharing_algorithm: str = "none" # currently supporting none and trie - @dataclass_json(undefined=Undefined.RAISE) @dataclass diff --git a/shortfin/python/shortfin_apps/llm/components/server_config.py b/shortfin/python/shortfin_apps/llm/components/server_config.py new file mode 100644 index 000000000..b0bee1b1b --- /dev/null +++ b/shortfin/python/shortfin_apps/llm/components/server_config.py @@ -0,0 +1,104 @@ +# Copyright 2024 Advanced Micro Devices, Inc. +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +"""Server configuration management.""" + +import json +import os +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional + +from dataclasses_json import dataclass_json, Undefined + + +@dataclass_json(undefined=Undefined.RAISE) +@dataclass +class ServerParams: + """Server configuration parameters.""" + + # KV cache configuration + prefix_sharing_algorithm: str = "none" # none or trie + + # Server runtime configuration + host: Optional[str] = None + port: int = 8000 + root_path: Optional[str] = None + timeout_keep_alive: int = 5 + + # Program isolation configuration + program_isolation: str = "per_call" + + # Device configuration + device_ids: list[str] = field(default_factory=list) + amdgpu_async_allocations: bool = False + amdgpu_allocators: Optional[str] = None + + @staticmethod + def load(config_path: Optional[Path] = None) -> "ServerParams": + """Load server configuration from a file or use defaults. + + Args: + config_path: Path to config file. If None, will check standard locations. + + Returns: + ServerParams instance with defaults or loaded values + """ + if config_path is None: + # Check standard locations + config_paths = [ + Path.home() / ".shortfin" / "server_config.json", + Path.home() / ".config" / "shortfin" / "server_config.json", + Path("/etc/shortfin/server_config.json"), + ] + + for path in config_paths: + if path.exists(): + config_path = path + break + + # Start with defaults + params = ServerParams() + + # Override with config file if it exists + if config_path and config_path.exists(): + with open(config_path) as f: + file_params = ServerParams.from_json(f.read()) + # Update only non-None values from file + for field in params.__dataclass_fields__: + file_value = getattr(file_params, field) + if file_value is not None: + setattr(params, field, file_value) + + return params + + def update_from_args(self, args) -> None: + """Update configuration from command line arguments. + + Args: + args: Parsed command line arguments + + Command line arguments take highest priority. + """ + # Only update fields that are present in args + for field in self.__dataclass_fields__: + if hasattr(args, field): + arg_value = getattr(args, field) + if arg_value is not None: # Only override if arg was provided + setattr(self, field, arg_value) + + def save(self, config_path: Optional[Path] = None): + """Save configuration to a file. + + Args: + config_path: Path to save to. If None, saves to ~/.shortfin/server_config.json + """ + if config_path is None: + config_path = Path.home() / ".shortfin" / "server_config.json" + + config_path.parent.mkdir(parents=True, exist_ok=True) + with open(config_path, "w") as f: + json.dump(self.to_dict(), f, indent=2) diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 75dc47146..59c74f21e 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -21,6 +21,7 @@ from .config_struct import ModelParams from .manager import SystemManager from .messages import InferenceExecRequest, InferencePhase, StrobeMessage +from .server_config import ServerParams from .tokenizer import Tokenizer logger = logging.getLogger(__name__) @@ -47,6 +48,7 @@ def __init__( sysman: SystemManager, tokenizer: Tokenizer, model_params: ModelParams, + server_params: "ServerParams", program_isolation: str = "per_call", ): self.name = name @@ -55,6 +57,7 @@ def __init__( self.sysman = sysman self.tokenizer = tokenizer self.model_params = model_params + self.server_params = server_params self.inference_parameters: list[sf.BaseProgramParameters] = [] self.inference_modules: list[sf.ProgramModule] = [] @@ -71,19 +74,19 @@ def __init__( page_pool = PagePool( devices=self.main_fiber.devices_dict.values(), config=page_pool_config ) - if model_params.paged_kv_cache.prefix_sharing_algorithm == "trie": + if server_params.prefix_sharing_algorithm == "trie": self.page_cache = TriePagedAttentionCache( page_pool=page_pool, tokens_per_page=model_params.paged_kv_cache.block_seq_stride, ) - elif model_params.paged_kv_cache.prefix_sharing_algorithm == "none": + elif server_params.prefix_sharing_algorithm == "none": self.page_cache = BasePagedAttentionCache( page_pool=page_pool, tokens_per_page=model_params.paged_kv_cache.block_seq_stride, ) else: raise ValueError( - f"Unknown model_params.paged_kv_cache.prefix_sharing_algorithm {model_params.paged_kv_cache.prefix_sharing_algorithm}. Currently only supporting 'trie' and 'none'." + f"Unknown prefix_sharing_algorithm {server_params.prefix_sharing_algorithm}. Currently only supporting 'trie' and 'none'." ) self.program_isolation = PROG_ISOLATIONS[program_isolation] diff --git a/shortfin/python/shortfin_apps/llm/server.py b/shortfin/python/shortfin_apps/llm/server.py index 364321554..0a2686dc6 100644 --- a/shortfin/python/shortfin_apps/llm/server.py +++ b/shortfin/python/shortfin_apps/llm/server.py @@ -112,12 +112,18 @@ def get_eos_from_tokenizer_config(json_path): def configure(args) -> SystemManager: + # Load server configuration with priority: command line > config file > defaults + server_params = ServerParams.load( + args.server_config if hasattr(args, "server_config") else None + ) + server_params.update_from_args(args) + # Setup system (configure devices, etc). sysman = SystemManager( device=args.device, - device_ids=args.device_ids, - async_allocs=args.amdgpu_async_allocations, - amdgpu_allocators=args.amdgpu_allocators, + device_ids=server_params.device_ids, + async_allocs=server_params.amdgpu_async_allocations, + amdgpu_allocators=server_params.amdgpu_allocators, ) # Setup each service we are hosting. @@ -131,7 +137,8 @@ def configure(args) -> SystemManager: sysman=sysman, tokenizer=tokenizer, model_params=model_params, - program_isolation=args.isolation, + server_params=server_params, + program_isolation=server_params.program_isolation, ) sm.load_inference_module(args.vmfb) sm.load_inference_parameters(*args.parameters, parameter_scope="model") @@ -215,6 +222,17 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG): default=None, help="Allocator to use during VMFB invocation.", ) + parser.add_argument( + "--server_config", + type=Path, + help="Path to server configuration file", + ) + parser.add_argument( + "--prefix_sharing_algorithm", + type=str, + choices=["none", "trie"], + help="Algorithm to use for prefix sharing in KV cache", + ) args = parser.parse_args(argv) if args.tokenizer_config_json is None: From 7aba70202e9c86296e92d436c6c3d9ac536cad97 Mon Sep 17 00:00:00 2001 From: Cedar Date: Tue, 28 Jan 2025 16:25:38 -0800 Subject: [PATCH 2/9] missing import --- shortfin/python/shortfin_apps/llm/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/shortfin/python/shortfin_apps/llm/server.py b/shortfin/python/shortfin_apps/llm/server.py index 0a2686dc6..7fe0e7605 100644 --- a/shortfin/python/shortfin_apps/llm/server.py +++ b/shortfin/python/shortfin_apps/llm/server.py @@ -28,6 +28,7 @@ from .components.io_struct import GenerateReqInput from .components.manager import SystemManager from .components.service import GenerateService +from .components.server_config import ServerParams from .components.tokenizer import Tokenizer From 16862a608acba4082207f2e85f6edf2bd75166ff Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 29 Jan 2025 09:44:16 -0800 Subject: [PATCH 3/9] remove residual arg --- app_tests/integration_tests/llm/server_management.py | 1 - 1 file changed, 1 deletion(-) diff --git a/app_tests/integration_tests/llm/server_management.py b/app_tests/integration_tests/llm/server_management.py index cd61f0cf6..e34627a9f 100644 --- a/app_tests/integration_tests/llm/server_management.py +++ b/app_tests/integration_tests/llm/server_management.py @@ -51,7 +51,6 @@ def __init__(self, config: ServerConfig): self.config = config self.process: Optional[subprocess.Popen] = None self.port: Optional[int] = None - self.config_path: Optional[Path] = None @staticmethod def find_available_port() -> int: From 3dd57607542cda2558a7f2b13057401bb2987f78 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 29 Jan 2025 09:46:25 -0800 Subject: [PATCH 4/9] linting changes --- app_tests/integration_tests/llm/server_management.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/app_tests/integration_tests/llm/server_management.py b/app_tests/integration_tests/llm/server_management.py index e34627a9f..af7aa8b43 100644 --- a/app_tests/integration_tests/llm/server_management.py +++ b/app_tests/integration_tests/llm/server_management.py @@ -1,12 +1,10 @@ """Handles server lifecycle and configuration.""" -import json import socket from contextlib import closing from dataclasses import dataclass import subprocess import time import requests -from pathlib import Path import sys from typing import Optional From fd2c50d13ddcf701ce6781878719da2fbd33fb87 Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 29 Jan 2025 09:56:26 -0800 Subject: [PATCH 5/9] move config structs together --- .../llm/components/{server_config.py => config_structs.py} | 0 shortfin/python/shortfin_apps/llm/components/service.py | 2 +- shortfin/python/shortfin_apps/llm/server.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename shortfin/python/shortfin_apps/llm/components/{server_config.py => config_structs.py} (100%) diff --git a/shortfin/python/shortfin_apps/llm/components/server_config.py b/shortfin/python/shortfin_apps/llm/components/config_structs.py similarity index 100% rename from shortfin/python/shortfin_apps/llm/components/server_config.py rename to shortfin/python/shortfin_apps/llm/components/config_structs.py diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 59c74f21e..79a426aae 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -21,7 +21,7 @@ from .config_struct import ModelParams from .manager import SystemManager from .messages import InferenceExecRequest, InferencePhase, StrobeMessage -from .server_config import ServerParams +from .config_structs import ServerParams from .tokenizer import Tokenizer logger = logging.getLogger(__name__) diff --git a/shortfin/python/shortfin_apps/llm/server.py b/shortfin/python/shortfin_apps/llm/server.py index 7fe0e7605..b98f04fdc 100644 --- a/shortfin/python/shortfin_apps/llm/server.py +++ b/shortfin/python/shortfin_apps/llm/server.py @@ -28,7 +28,7 @@ from .components.io_struct import GenerateReqInput from .components.manager import SystemManager from .components.service import GenerateService -from .components.server_config import ServerParams +from .components.config_structs import ServerParams from .components.tokenizer import Tokenizer From 2e2911b21d897ced27a336bb8044183c884a490c Mon Sep 17 00:00:00 2001 From: Cedar Date: Wed, 29 Jan 2025 10:04:08 -0800 Subject: [PATCH 6/9] reorganize ServerParams to merge into config_struct.py --- .../llm/components/config_struct.py | 94 +++++++++++++++- .../llm/components/config_structs.py | 104 ------------------ .../shortfin_apps/llm/components/service.py | 3 +- shortfin/python/shortfin_apps/llm/server.py | 3 +- 4 files changed, 95 insertions(+), 109 deletions(-) delete mode 100644 shortfin/python/shortfin_apps/llm/components/config_structs.py diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py index 8ff826d38..8415edc88 100644 --- a/shortfin/python/shortfin_apps/llm/components/config_struct.py +++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py @@ -56,8 +56,11 @@ sense of scale only: real workloads will vary. """ -from dataclasses import dataclass +import json +import os +from dataclasses import dataclass, field from pathlib import Path +from typing import Optional import dataclasses_json from dataclasses_json import dataclass_json, Undefined @@ -192,3 +195,92 @@ def human_size(num, suffix="B"): return f"{num:3.1f}{unit}{suffix}" num /= 1024.0 return f"{num:.1f}Yi{suffix}" + + +@dataclass_json(undefined=Undefined.RAISE) +@dataclass +class ServerParams: + """Server configuration parameters.""" + + # KV cache configuration + prefix_sharing_algorithm: str = "none" # none or trie + + # Server runtime configuration + host: Optional[str] = None + port: int = 8000 + root_path: Optional[str] = None + timeout_keep_alive: int = 5 + + # Program isolation configuration + program_isolation: str = "per_call" + + # Device configuration + device_ids: list[str] = field(default_factory=list) + amdgpu_async_allocations: bool = False + amdgpu_allocators: Optional[str] = None + + @staticmethod + def load(config_path: Optional[Path] = None) -> "ServerParams": + """Load server configuration from a file or use defaults. + + Args: + config_path: Path to config file. If None, will check standard locations. + + Returns: + ServerParams instance with defaults or loaded values + """ + if config_path is None: + # Check standard locations + config_paths = [ + Path.home() / ".shortfin" / "server_config.json", + Path.home() / ".config" / "shortfin" / "server_config.json", + Path("/etc/shortfin/server_config.json"), + ] + + for path in config_paths: + if path.exists(): + config_path = path + break + + # Start with defaults + params = ServerParams() + + # Override with config file if it exists + if config_path and config_path.exists(): + with open(config_path) as f: + file_params = ServerParams.from_json(f.read()) + # Update only non-None values from file + for field in params.__dataclass_fields__: + file_value = getattr(file_params, field) + if file_value is not None: + setattr(params, field, file_value) + + return params + + def update_from_args(self, args) -> None: + """Update configuration from command line arguments. + + Args: + args: Parsed command line arguments + + Command line arguments take highest priority. + """ + # Only update fields that are present in args + for field in self.__dataclass_fields__: + if hasattr(args, field): + arg_value = getattr(args, field) + if arg_value is not None: # Only override if arg was provided + setattr(self, field, arg_value) + + def save(self, config_path: Optional[Path] = None): + """Save configuration to a file. + + Args: + config_path: Path to save to. If None, saves to ~/.shortfin/server_config.json + """ + if config_path is None: + config_path = Path.home() / ".shortfin" / "server_config.json" + + config_path.parent.mkdir(parents=True, exist_ok=True) + with open(config_path, "w") as f: + json.dump(self.to_dict(), f, indent=2) diff --git a/shortfin/python/shortfin_apps/llm/components/config_structs.py b/shortfin/python/shortfin_apps/llm/components/config_structs.py deleted file mode 100644 index b0bee1b1b..000000000 --- a/shortfin/python/shortfin_apps/llm/components/config_structs.py +++ /dev/null @@ -1,104 +0,0 @@ -# Copyright 2024 Advanced Micro Devices, Inc. -# -# Licensed under the Apache License v2.0 with LLVM Exceptions. -# See https://llvm.org/LICENSE.txt for license information. -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -"""Server configuration management.""" - -import json -import os -from dataclasses import dataclass, field -from pathlib import Path -from typing import Optional - -from dataclasses_json import dataclass_json, Undefined - - -@dataclass_json(undefined=Undefined.RAISE) -@dataclass -class ServerParams: - """Server configuration parameters.""" - - # KV cache configuration - prefix_sharing_algorithm: str = "none" # none or trie - - # Server runtime configuration - host: Optional[str] = None - port: int = 8000 - root_path: Optional[str] = None - timeout_keep_alive: int = 5 - - # Program isolation configuration - program_isolation: str = "per_call" - - # Device configuration - device_ids: list[str] = field(default_factory=list) - amdgpu_async_allocations: bool = False - amdgpu_allocators: Optional[str] = None - - @staticmethod - def load(config_path: Optional[Path] = None) -> "ServerParams": - """Load server configuration from a file or use defaults. - - Args: - config_path: Path to config file. If None, will check standard locations. - - Returns: - ServerParams instance with defaults or loaded values - """ - if config_path is None: - # Check standard locations - config_paths = [ - Path.home() / ".shortfin" / "server_config.json", - Path.home() / ".config" / "shortfin" / "server_config.json", - Path("/etc/shortfin/server_config.json"), - ] - - for path in config_paths: - if path.exists(): - config_path = path - break - - # Start with defaults - params = ServerParams() - - # Override with config file if it exists - if config_path and config_path.exists(): - with open(config_path) as f: - file_params = ServerParams.from_json(f.read()) - # Update only non-None values from file - for field in params.__dataclass_fields__: - file_value = getattr(file_params, field) - if file_value is not None: - setattr(params, field, file_value) - - return params - - def update_from_args(self, args) -> None: - """Update configuration from command line arguments. - - Args: - args: Parsed command line arguments - - Command line arguments take highest priority. - """ - # Only update fields that are present in args - for field in self.__dataclass_fields__: - if hasattr(args, field): - arg_value = getattr(args, field) - if arg_value is not None: # Only override if arg was provided - setattr(self, field, arg_value) - - def save(self, config_path: Optional[Path] = None): - """Save configuration to a file. - - Args: - config_path: Path to save to. If None, saves to ~/.shortfin/server_config.json - """ - if config_path is None: - config_path = Path.home() / ".shortfin" / "server_config.json" - - config_path.parent.mkdir(parents=True, exist_ok=True) - with open(config_path, "w") as f: - json.dump(self.to_dict(), f, indent=2) diff --git a/shortfin/python/shortfin_apps/llm/components/service.py b/shortfin/python/shortfin_apps/llm/components/service.py index 79a426aae..0984a412c 100644 --- a/shortfin/python/shortfin_apps/llm/components/service.py +++ b/shortfin/python/shortfin_apps/llm/components/service.py @@ -18,10 +18,9 @@ ) from .kvcache.trie_attention_cache import TriePagedAttentionCache from .kvcache.page_pool import PagePoolConfig, PagePool, PageInfo -from .config_struct import ModelParams +from .config_struct import ModelParams, ServerParams from .manager import SystemManager from .messages import InferenceExecRequest, InferencePhase, StrobeMessage -from .config_structs import ServerParams from .tokenizer import Tokenizer logger = logging.getLogger(__name__) diff --git a/shortfin/python/shortfin_apps/llm/server.py b/shortfin/python/shortfin_apps/llm/server.py index b98f04fdc..24794eff2 100644 --- a/shortfin/python/shortfin_apps/llm/server.py +++ b/shortfin/python/shortfin_apps/llm/server.py @@ -24,11 +24,10 @@ from .components.generate import ClientGenerateBatchProcess -from .components.config_struct import ModelParams +from .components.config_struct import ModelParams, ServerParams from .components.io_struct import GenerateReqInput from .components.manager import SystemManager from .components.service import GenerateService -from .components.config_structs import ServerParams from .components.tokenizer import Tokenizer From c75be75d19a0937f54725decdacb49a96dee6da3 Mon Sep 17 00:00:00 2001 From: Xida Date: Wed, 29 Jan 2025 21:56:12 +0000 Subject: [PATCH 7/9] docstring improvements --- .../llm/components/config_struct.py | 114 ++++++++++-------- 1 file changed, 62 insertions(+), 52 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py index 8415edc88..527287b71 100644 --- a/shortfin/python/shortfin_apps/llm/components/config_struct.py +++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py @@ -4,56 +4,12 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -"""Configuration objects. - -Parameters that are intrinsic to a specific model. - -In a typical transformer model, the KV cache is organized similar to (mapped to -our parameter names below): - k = tensor.empty(transformer_block_count, batch_size, seq, - attn_head_count_kv, attn_head_dim) - v = ... - -For context, a popular model has parameters of: - attn_dtype_size = 2 # (fp16) - max_seq_len = 2048 - transformer_block_count = 32 - attn_head_count_kv = 32 - attn_head_dim = 128 # (dim / head_count) - -If paging, then we primarily care about the organization of a single block, where -a block represents a single position in the sequence for a single item in the batch. -Therefore, it will be organized like: - block = torch.empty(transformer_block_count, 2, attn_head_count_kv, attn_head_dim) - -In this scenario, we declare that one block holds the KV cache for all transformer -block layers because it reduces the accounting. As such, for the above example, -a single position in the sequence will be 524,288 bytes, assuming a 2-byte element -type. If we choose to block by block_stride=16 positions, each block will be 8MiB. -Assuming we wanted to dedicate 12GiB to the block cache, this would equate to 1536 -blocks for a total number of sequence positions of 24,576. - -These are well-known numbers but are derived above to give a sense of scale. - -In order to indirect through to the block cache, we have to provide the index map -to specific invocations: - -* Prefill: Prefill is only writing to the blocks from [0:prompt_len], so it will - need write indices of [batch_size, prompt_len // block_stride + 1]. -* Decode step: Decode is auto-regressive, and needs to first compute the new kv - row and then attend over all rows in the cache up to this point in the sequence. - -If wanting to avoid dynamic allocation of transients, we can also pool the index -tables based on the maximum batch size and maximum sequence length. Since all -block cache sizes are well within the range of an i16, we will use that for storage. -Therefore, each batch invocation would need a block lookup table of: - - byte_size = max_batch_size * (max_seq_len // block_stride) * sizeof(int16_t) - -For a max_batch_size of 16, this is 4KiB of block index table lookups per -invocation. We don't have to statically allocate this, but the system is more -predictable if we just reserve what we need. Again, numbers are given to give a -sense of scale only: real workloads will vary. +""" +Configuration objects. + +Classes: +- ModelParams: for reading and managing config keys specified in `config.json` files exported by `python -m sharktank.examples.export_paged_llm_v1` +- ServerParams: for specifying config keys needed by `python -m shortfin_apps.llm.server` """ import json @@ -81,7 +37,55 @@ def _decode_dtype(name: str) -> sfnp.DType: @dataclass_json(undefined=Undefined.RAISE) @dataclass class PagedKVCacheParams: - """Parameters for the paged KV cache.""" + """Parameters for the paged KV cache. + + In a typical transformer model, the KV cache is organized similar to (mapped to + our parameter names below): + k = tensor.empty(transformer_block_count, batch_size, seq, + attn_head_count_kv, attn_head_dim) + v = ... + + For context, a popular model has parameters of: + attn_dtype_size = 2 # (fp16) + max_seq_len = 2048 + transformer_block_count = 32 + attn_head_count_kv = 32 + attn_head_dim = 128 # (dim / head_count) + + If paging, then we primarily care about the organization of a single block, where + a block represents a single position in the sequence for a single item in the batch. + Therefore, it will be organized like: + block = torch.empty(transformer_block_count, 2, attn_head_count_kv, attn_head_dim) + + In this scenario, we declare that one block holds the KV cache for all transformer + block layers because it reduces the accounting. As such, for the above example, + a single position in the sequence will be 524,288 bytes, assuming a 2-byte element + type. If we choose to block by block_stride=16 positions, each block will be 8MiB. + Assuming we wanted to dedicate 12GiB to the block cache, this would equate to 1536 + blocks for a total number of sequence positions of 24,576. + + These are well-known numbers but are derived above to give a sense of scale. + + In order to indirect through to the block cache, we have to provide the index map + to specific invocations: + + * Prefill: Prefill is only writing to the blocks from [0:prompt_len], so it will + need write indices of [batch_size, prompt_len // block_stride + 1]. + * Decode step: Decode is auto-regressive, and needs to first compute the new kv + row and then attend over all rows in the cache up to this point in the sequence. + + If wanting to avoid dynamic allocation of transients, we can also pool the index + tables based on the maximum batch size and maximum sequence length. Since all + block cache sizes are well within the range of an i16, we will use that for storage. + Therefore, each batch invocation would need a block lookup table of: + + byte_size = max_batch_size * (max_seq_len // block_stride) * sizeof(int16_t) + + For a max_batch_size of 16, this is 4KiB of block index table lookups per + invocation. We don't have to statically allocate this, but the system is more + predictable if we just reserve what we need. Again, numbers are given to give a + sense of scale only: real workloads will vary. + """ # Tokens per page. block_seq_stride: int @@ -200,7 +204,13 @@ def human_size(num, suffix="B"): @dataclass_json(undefined=Undefined.RAISE) @dataclass class ServerParams: - """Server configuration parameters.""" + """ + Parameters relevant to a specific server launch. + + shortfin_apps.llm.server accepts an optional json server_config file for this + commandline arguments for setting attributes of this class. + + Commandline args take priority over server_config.json, which takes precedence over defaults defined in this dataclass. + """ # KV cache configuration prefix_sharing_algorithm: str = "none" # none or trie From 854c93391b8a15e9aa1c427dac8e44402afc6c13 Mon Sep 17 00:00:00 2001 From: Xida Date: Wed, 29 Jan 2025 21:59:46 +0000 Subject: [PATCH 8/9] simplify ServerParams config struct --- .../llm/components/config_struct.py | 30 +------------------ 1 file changed, 1 insertion(+), 29 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py index 527287b71..4f191ad56 100644 --- a/shortfin/python/shortfin_apps/llm/components/config_struct.py +++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py @@ -234,27 +234,13 @@ def load(config_path: Optional[Path] = None) -> "ServerParams": """Load server configuration from a file or use defaults. Args: - config_path: Path to config file. If None, will check standard locations. + config_path: Path to config file. Returns: ServerParams instance with defaults or loaded values """ - if config_path is None: - # Check standard locations - config_paths = [ - Path.home() / ".shortfin" / "server_config.json", - Path.home() / ".config" / "shortfin" / "server_config.json", - Path("/etc/shortfin/server_config.json"), - ] - - for path in config_paths: - if path.exists(): - config_path = path - break - # Start with defaults params = ServerParams() - # Override with config file if it exists if config_path and config_path.exists(): with open(config_path) as f: @@ -264,7 +250,6 @@ def load(config_path: Optional[Path] = None) -> "ServerParams": file_value = getattr(file_params, field) if file_value is not None: setattr(params, field, file_value) - return params def update_from_args(self, args) -> None: @@ -281,16 +266,3 @@ def update_from_args(self, args) -> None: arg_value = getattr(args, field) if arg_value is not None: # Only override if arg was provided setattr(self, field, arg_value) - - def save(self, config_path: Optional[Path] = None): - """Save configuration to a file. - - Args: - config_path: Path to save to. If None, saves to ~/.shortfin/server_config.json - """ - if config_path is None: - config_path = Path.home() / ".shortfin" / "server_config.json" - - config_path.parent.mkdir(parents=True, exist_ok=True) - with open(config_path, "w") as f: - json.dump(self.to_dict(), f, indent=2) From 9cc707fc81ca3485222bd487df54f1d9f25c900d Mon Sep 17 00:00:00 2001 From: Xida Date: Wed, 29 Jan 2025 22:06:21 +0000 Subject: [PATCH 9/9] comments and docstring improvements --- .../python/shortfin_apps/llm/components/config_struct.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/shortfin/python/shortfin_apps/llm/components/config_struct.py b/shortfin/python/shortfin_apps/llm/components/config_struct.py index 4f191ad56..27ae7269b 100644 --- a/shortfin/python/shortfin_apps/llm/components/config_struct.py +++ b/shortfin/python/shortfin_apps/llm/components/config_struct.py @@ -231,7 +231,7 @@ class ServerParams: @staticmethod def load(config_path: Optional[Path] = None) -> "ServerParams": - """Load server configuration from a file or use defaults. + """Create a new ServerParams object by overriding defaults from a `server_config.json` file. Args: config_path: Path to config file. @@ -239,9 +239,7 @@ def load(config_path: Optional[Path] = None) -> "ServerParams": Returns: ServerParams instance with defaults or loaded values """ - # Start with defaults params = ServerParams() - # Override with config file if it exists if config_path and config_path.exists(): with open(config_path) as f: file_params = ServerParams.from_json(f.read()) @@ -260,9 +258,10 @@ def update_from_args(self, args) -> None: Command line arguments take highest priority. """ - # Only update fields that are present in args for field in self.__dataclass_fields__: if hasattr(args, field): arg_value = getattr(args, field) - if arg_value is not None: # Only override if arg was provided + if ( + arg_value is not None + ): # Only override if a cmdline arg of the same name was provided setattr(self, field, arg_value)