Move server-specific config args out of ModelParams and into new ServerParams in config_struct.py #877

Merged
26 changes: 2 additions & 24 deletions app_tests/integration_tests/llm/server_management.py
@@ -1,12 +1,10 @@
"""Handles server lifecycle and configuration."""
import json
import socket
from contextlib import closing
from dataclasses import dataclass
import subprocess
import time
import requests
from pathlib import Path
import sys
from typing import Optional

@@ -51,7 +49,6 @@ def __init__(self, config: ServerConfig):
self.config = config
self.process: Optional[subprocess.Popen] = None
self.port: Optional[int] = None
self.config_path: Optional[Path] = None

@staticmethod
def find_available_port() -> int:
@@ -61,42 +58,23 @@ def find_available_port() -> int:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
return s.getsockname()[1]

def _write_config(self) -> Path:
"""Creates server config by extending the exported model config."""
# TODO: eliminate this by moving prefix sharing algorithm to be a cmdline arg of server.py
source_config_path = self.config.artifacts.config_path
server_config_path = (
source_config_path.parent
/ f"server_config_{self.config.prefix_sharing_algorithm}.json"
)

# Read the exported config as base
with open(source_config_path) as f:
config = json.load(f)
config["paged_kv_cache"][
"prefix_sharing_algorithm"
] = self.config.prefix_sharing_algorithm
with open(server_config_path, "w") as f:
json.dump(config, f)
return server_config_path

def start(self) -> None:
"""Starts the server process."""
if self.process is not None:
raise RuntimeError("Server is already running")

self.config_path = self._write_config()
self.port = self.find_available_port()

cmd = [
sys.executable,
"-m",
"shortfin_apps.llm.server",
f"--tokenizer_json={self.config.artifacts.tokenizer_path}",
f"--model_config={self.config_path}",
f"--model_config={self.config.artifacts.config_path}",
f"--vmfb={self.config.artifacts.vmfb_path}",
f"--parameters={self.config.artifacts.weights_path}",
f"--port={self.port}",
f"--prefix_sharing_algorithm={self.config.prefix_sharing_algorithm}",
]
cmd.extend(self.config.device_settings.server_flags)

179 changes: 125 additions & 54 deletions shortfin/python/shortfin_apps/llm/components/config_struct.py
@@ -4,60 +4,19 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""Configuration objects.

Parameters that are intrinsic to a specific model.

In a typical transformer model, the KV cache is organized similar to (mapped to
our parameter names below):
k = tensor.empty(transformer_block_count, batch_size, seq,
attn_head_count_kv, attn_head_dim)
v = ...

For context, a popular model has parameters of:
attn_dtype_size = 2 # (fp16)
max_seq_len = 2048
transformer_block_count = 32
attn_head_count_kv = 32
attn_head_dim = 128 # (dim / head_count)

If paging, then we primarily care about the organization of a single block, where
a block represents a single position in the sequence for a single item in the batch.
Therefore, it will be organized like:
block = torch.empty(transformer_block_count, 2, attn_head_count_kv, attn_head_dim)

In this scenario, we declare that one block holds the KV cache for all transformer
block layers because it reduces the accounting. As such, for the above example,
a single position in the sequence will be 524,288 bytes, assuming a 2-byte element
type. If we choose to block by block_stride=16 positions, each block will be 8MiB.
Assuming we wanted to dedicate 12GiB to the block cache, this would equate to 1536
blocks for a total number of sequence positions of 24,576.

These are well-known numbers but are derived above to give a sense of scale.

In order to indirect through to the block cache, we have to provide the index map
to specific invocations:

* Prefill: Prefill is only writing to the blocks from [0:prompt_len], so it will
need write indices of [batch_size, prompt_len // block_stride + 1].
* Decode step: Decode is auto-regressive, and needs to first compute the new kv
row and then attend over all rows in the cache up to this point in the sequence.

If wanting to avoid dynamic allocation of transients, we can also pool the index
tables based on the maximum batch size and maximum sequence length. Since all
block cache sizes are well within the range of an i16, we will use that for storage.
Therefore, each batch invocation would need a block lookup table of:

byte_size = max_batch_size * (max_seq_len // block_stride) * sizeof(int16_t)

For a max_batch_size of 16, this is 4KiB of block index table lookups per
invocation. We don't have to statically allocate this, but the system is more
predictable if we just reserve what we need. Again, numbers are given to give a
sense of scale only: real workloads will vary.
"""
Configuration objects.

from dataclasses import dataclass
Classes:
- ModelParams: for reading and managing config keys specified in `config.json` files exported by `python -m sharktank.examples.export_paged_llm_v1`
- ServerParams: for specifying config keys needed by `python -m shortfin_apps.llm.server`
"""

import json
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import dataclasses_json
from dataclasses_json import dataclass_json, Undefined
@@ -78,7 +37,55 @@ def _decode_dtype(name: str) -> sfnp.DType:
@dataclass_json(undefined=Undefined.RAISE)
@dataclass
class PagedKVCacheParams:
"""Parameters for the paged KV cache."""
"""Parameters for the paged KV cache.

In a typical transformer model, the KV cache is organized similar to (mapped to
our parameter names below):
k = tensor.empty(transformer_block_count, batch_size, seq,
attn_head_count_kv, attn_head_dim)
v = ...

For context, a popular model has parameters of:
attn_dtype_size = 2 # (fp16)
max_seq_len = 2048
transformer_block_count = 32
attn_head_count_kv = 32
attn_head_dim = 128 # (dim / head_count)

If paging, then we primarily care about the organization of a single block, where
a block represents a single position in the sequence for a single item in the batch.
Therefore, it will be organized like:
block = torch.empty(transformer_block_count, 2, attn_head_count_kv, attn_head_dim)

In this scenario, we declare that one block holds the KV cache for all transformer
block layers because it reduces the accounting. As such, for the above example,
a single position in the sequence will be 524,288 bytes, assuming a 2-byte element
type. If we choose to block by block_stride=16 positions, each block will be 8MiB.
Assuming we wanted to dedicate 12GiB to the block cache, this would equate to 1536
blocks for a total number of sequence positions of 24,576.

These are well-known numbers but are derived above to give a sense of scale.

In order to indirect through to the block cache, we have to provide the index map
to specific invocations:

* Prefill: Prefill is only writing to the blocks from [0:prompt_len], so it will
need write indices of [batch_size, prompt_len // block_stride + 1].
* Decode step: Decode is auto-regressive, and needs to first compute the new kv
row and then attend over all rows in the cache up to this point in the sequence.

If wanting to avoid dynamic allocation of transients, we can also pool the index
tables based on the maximum batch size and maximum sequence length. Since all
block cache sizes are well within the range of an i16, we will use that for storage.
Therefore, each batch invocation would need a block lookup table of:

byte_size = max_batch_size * (max_seq_len // block_stride) * sizeof(int16_t)

For a max_batch_size of 16, this is 4KiB of block index table lookups per
invocation. We don't have to statically allocate this, but the system is more
predictable if we just reserve what we need. Again, numbers are given to give a
sense of scale only: real workloads will vary.
"""

# Tokens per page.
block_seq_stride: int
@@ -91,8 +98,6 @@ class PagedKVCacheParams:
# Default: 256
device_block_count: int

prefix_sharing_algorithm: str = "none" # currently supporting none and trie


@dataclass_json(undefined=Undefined.RAISE)
@dataclass
@@ -194,3 +199,69 @@ def human_size(num, suffix="B"):
return f"{num:3.1f}{unit}{suffix}"
num /= 1024.0
return f"{num:.1f}Yi{suffix}"


@dataclass_json(undefined=Undefined.RAISE)
@dataclass
class ServerParams:
"""
Parameters relevant to a specific server launch.

shortfin_apps.llm.server accepts an optional json server_config file for this + commandline arguments for setting attributes of this class.

Commandline args take priority over server_config.json, which takes precedence over defaults defined in this dataclass.
"""

# KV cache configuration
prefix_sharing_algorithm: str = "none" # none or trie

# Server runtime configuration
host: Optional[str] = None
port: int = 8000
root_path: Optional[str] = None
timeout_keep_alive: int = 5

# Program isolation configuration
program_isolation: str = "per_call"

# Device configuration
device_ids: list[str] = field(default_factory=list)
amdgpu_async_allocations: bool = False
amdgpu_allocators: Optional[str] = None

@staticmethod
def load(config_path: Optional[Path] = None) -> "ServerParams":
"""Create a new ServerParams object by overriding defaults from a `server_config.json` file.

Args:
config_path: Path to config file.

Returns:
ServerParams instance with defaults or loaded values
"""
params = ServerParams()
if config_path and config_path.exists():
with open(config_path) as f:
file_params = ServerParams.from_json(f.read())
# Update only non-None values from file
for field in params.__dataclass_fields__:
file_value = getattr(file_params, field)
if file_value is not None:
setattr(params, field, file_value)
return params

def update_from_args(self, args) -> None:
"""Update configuration from command line arguments.

Args:
args: Parsed command line arguments

Command line arguments take highest priority.
"""
for field in self.__dataclass_fields__:
if hasattr(args, field):
arg_value = getattr(args, field)
if (
arg_value is not None
): # Only override if a cmdline arg of the same name was provided
setattr(self, field, arg_value)
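
The precedence described in the ServerParams docstring (dataclass defaults, then server_config.json, then command-line arguments) can be exercised roughly as in the sketch below. The config-file contents, path, and argument values are made up for illustration, and the import assumes the package layout shown in this diff:

    import argparse
    import json
    from pathlib import Path

    from shortfin_apps.llm.components.config_struct import ServerParams

    # A hypothetical server_config.json overriding two defaults.
    config_path = Path("/tmp/server_config.json")
    config_path.write_text(json.dumps({"prefix_sharing_algorithm": "trie", "port": 9000}))

    # Defaults, then values from the file.
    params = ServerParams.load(config_path)
    assert params.prefix_sharing_algorithm == "trie"   # from the file
    assert params.port == 9000                         # from the file
    assert params.program_isolation == "per_call"      # dataclass default

    # Command-line arguments win over the file; None values are ignored.
    args = argparse.Namespace(prefix_sharing_algorithm="none", port=None)
    params.update_from_args(args)
    assert params.prefix_sharing_algorithm == "none"   # CLI override
    assert params.port == 9000                         # untouched by the None arg
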
10 changes: 6 additions & 4 deletions shortfin/python/shortfin_apps/llm/components/service.py
@@ -18,7 +18,7 @@
)
from .kvcache.trie_attention_cache import TriePagedAttentionCache
from .kvcache.page_pool import PagePoolConfig, PagePool, PageInfo
from .config_struct import ModelParams
from .config_struct import ModelParams, ServerParams
from .manager import SystemManager
from .messages import InferenceExecRequest, InferencePhase, StrobeMessage
from .tokenizer import Tokenizer
@@ -47,6 +47,7 @@ def __init__(
sysman: SystemManager,
tokenizer: Tokenizer,
model_params: ModelParams,
server_params: "ServerParams",
program_isolation: str = "per_call",
):
self.name = name
@@ -55,6 +56,7 @@ def __init__(
self.sysman = sysman
self.tokenizer = tokenizer
self.model_params = model_params
self.server_params = server_params
self.inference_parameters: list[sf.BaseProgramParameters] = []
self.inference_modules: list[sf.ProgramModule] = []

@@ -71,19 +73,19 @@ def __init__(
page_pool = PagePool(
devices=self.main_fiber.devices_dict.values(), config=page_pool_config
)
if model_params.paged_kv_cache.prefix_sharing_algorithm == "trie":
if server_params.prefix_sharing_algorithm == "trie":
self.page_cache = TriePagedAttentionCache(
page_pool=page_pool,
tokens_per_page=model_params.paged_kv_cache.block_seq_stride,
)
elif model_params.paged_kv_cache.prefix_sharing_algorithm == "none":
elif server_params.prefix_sharing_algorithm == "none":
self.page_cache = BasePagedAttentionCache(
page_pool=page_pool,
tokens_per_page=model_params.paged_kv_cache.block_seq_stride,
)
else:
raise ValueError(
f"Unknown model_params.paged_kv_cache.prefix_sharing_algorithm {model_params.paged_kv_cache.prefix_sharing_algorithm}. Currently only supporting 'trie' and 'none'."
f"Unknown prefix_sharing_algorithm {server_params.prefix_sharing_algorithm}. Currently only supporting 'trie' and 'none'."
)

self.program_isolation = PROG_ISOLATIONS[program_isolation]
28 changes: 23 additions & 5 deletions shortfin/python/shortfin_apps/llm/server.py
@@ -24,7 +24,7 @@


from .components.generate import ClientGenerateBatchProcess
from .components.config_struct import ModelParams
from .components.config_struct import ModelParams, ServerParams
from .components.io_struct import GenerateReqInput
from .components.manager import SystemManager
from .components.service import GenerateService
@@ -112,12 +112,18 @@ def get_eos_from_tokenizer_config(json_path):


def configure(args) -> SystemManager:
# Load server configuration with priority: command line > config file > defaults
server_params = ServerParams.load(
args.server_config if hasattr(args, "server_config") else None
)
server_params.update_from_args(args)

# Setup system (configure devices, etc).
sysman = SystemManager(
device=args.device,
device_ids=args.device_ids,
async_allocs=args.amdgpu_async_allocations,
amdgpu_allocators=args.amdgpu_allocators,
device_ids=server_params.device_ids,
async_allocs=server_params.amdgpu_async_allocations,
amdgpu_allocators=server_params.amdgpu_allocators,
)

# Setup each service we are hosting.
@@ -131,7 +137,8 @@ def configure(args) -> SystemManager:
sysman=sysman,
tokenizer=tokenizer,
model_params=model_params,
program_isolation=args.isolation,
server_params=server_params,
program_isolation=server_params.program_isolation,
)
sm.load_inference_module(args.vmfb)
sm.load_inference_parameters(*args.parameters, parameter_scope="model")
@@ -215,6 +222,17 @@ def main(argv, log_config=uvicorn.config.LOGGING_CONFIG):
default=None,
help="Allocator to use during VMFB invocation.",
)
parser.add_argument(
"--server_config",
type=Path,
help="Path to server configuration file",
)
parser.add_argument(
"--prefix_sharing_algorithm",
type=str,
choices=["none", "trie"],
help="Algorithm to use for prefix sharing in KV cache",
)
args = parser.parse_args(argv)

if args.tokenizer_config_json is None:
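
Putting the pieces together, a launch that combines the new --server_config and --prefix_sharing_algorithm flags with the existing server arguments looks roughly like the sketch below, mirroring the command the integration-test harness builds. Every path is a placeholder:

    import subprocess
    import sys

    # Hypothetical invocation; all artifact paths are placeholders.
    cmd = [
        sys.executable,
        "-m",
        "shortfin_apps.llm.server",
        "--tokenizer_json=/path/to/tokenizer.json",
        "--model_config=/path/to/config.json",          # exported model config, no longer patched
        "--vmfb=/path/to/model.vmfb",
        "--parameters=/path/to/model.irpa",
        "--server_config=/path/to/server_config.json",  # optional file-based ServerParams
        "--prefix_sharing_algorithm=trie",              # CLI value wins over the file
        "--port=8000",
    ]
    process = subprocess.Popen(cmd)
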