Integrate Flash-Decoding into engine (#181)
* test stub

* wip

* wip

* wip

* compiled

* wip

* fix

* fix

* wip, decode with flash decoding works

* all work

* add paged_kv_cache_type option

* read kv_type from artifact

* black

* refactor attention backend

* minor clean up

* Integrate flash-decoding into mlc-serve

* remove --use-vllm-attention

* wip decode_multi_query integration

* temp handling for multi-query logits

* remove tmp support for multi-query decode

* typo

* use block size 128 or 64 when possible

* remove unused var

* merge fix
masahi authored Feb 12, 2024
1 parent ab14322 commit edf8d27
Showing 5 changed files with 107 additions and 61 deletions.
53 changes: 24 additions & 29 deletions mlc_llm/core.py
@@ -4,7 +4,6 @@
import json
import os
import pickle
import shutil
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Optional

@@ -157,8 +156,9 @@ class BuildArgs:
pdb: bool
If set, drop into a pdb debugger on error.
use_vllm_attention: bool
Use vLLM paged KV cache and attention kernel, only relevant when enable_batching=True.
paged_kv_cache_type: str
The type of paged KV cache to use, only relevant when enable_batching=True.
Currently "vllm" and "flash-decoding" are supported.
"""
model: str = field(
default="auto",
@@ -392,19 +392,8 @@ class BuildArgs:
"action": "store_true",
},
)
# TODO(masahi): Remove the use of this option with paged_kv_cache_type
use_vllm_attention: bool = field(
default=False,
metadata={
"help": (
"Use vLLM paged KV cache and attention kernel, only relevant when "
"enable_batching=True."
),
"action": "store_true",
},
)
paged_kv_cache_type: str = field(
default="vllm",
default="",
metadata={"help": "The type of paged KV cache, either vllm or flash-decoding"},
)

@@ -462,12 +451,18 @@ def _parse_args(parsed) -> argparse.Namespace:
utils.parse_target(parsed)
utils.argparse_postproc_common(parsed)

if parsed.use_vllm_attention:
assert parsed.enable_batching, "--enable_batching is required for using vLLM attention."
assert parsed.target_kind == "cuda", "vLLM attention is only supported for CUDA."
assert tvm.get_global_func(
"tvm.contrib.vllm.single_query_cached_kv_attention", True
), "TVM needs to be built with -DUSE_VLLM=ON."
if parsed.paged_kv_cache_type in ["vllm", "flash-decoding"]:
assert parsed.enable_batching, "--enable_batching is required for using vLLM or Flash-Decoding."
assert parsed.target_kind == "cuda", "vLLM and Flash-Decoding are only supported for CUDA."

if parsed.paged_kv_cache_type == "vllm":
assert tvm.get_global_func(
"tvm.contrib.vllm.single_query_cached_kv_attention", True
), "TVM needs to be built with -DUSE_VLLM=ON to use vLLM."
elif parsed.paged_kv_cache_type == "flash-decoding":
assert tvm.get_global_func(
"tvm.contrib.flash_attn.flash_decoding_with_paged_kvcache", True
), "TVM needs to be built with -DUSE_CUTLASS=ON to use Flash-Decoding."

model_name = [
parsed.model,
@@ -588,20 +583,20 @@ def mod_transform_before_build(
"decode",
]

if not args.use_vllm_attention:
model_names += [
"create_kv_cache",
"softmax_with_temperature",
"get_metadata",
]
else:
if args.paged_kv_cache_type in ["vllm", "flash-decoding"]:
# This is equivalent to prefill but without KV cache. It is used for
# determining the number of paged cache blocks that can be allocated.
model_names.append("evaluate")
model_names.append("evaluate_multi_query")

if args.paged_kv_cache_type == "flash-decoding":
model_names.append("decode_multi_query")
else:
model_names += [
"create_kv_cache",
"softmax_with_temperature",
"get_metadata",
]

if args.sep_embed:
model_names = ["embed", "prefill_with_embed"] + model_names[1:]
@@ -879,7 +874,7 @@ def build_model_from_args(args: argparse.Namespace):
"mixtral": llama,
}

if args.use_vllm_attention:
if args.paged_kv_cache_type in ["vllm", "flash-decoding"]:
model_generators["llama"] = llama_batched_vllm
model_generators["mistral"] = llama_batched_vllm
model_generators["mixtral"] = llama_batched_vllm
4 changes: 4 additions & 0 deletions serve/mlc_serve/model/base.py
@@ -10,6 +10,7 @@ class ModelArtifactConfig:
model_artifact_path: Optional[str] = None
num_shards: Optional[int] = None
quantization: Optional[str] = None
paged_kv_cache_type: Optional[str] = None
model_type: Optional[str] = None
library_name: Optional[str] = None
max_context_length: Optional[int] = None
@@ -49,4 +50,7 @@ def get_model_artifact_config(model_artifact_path):
with open(config_file_path, mode="rt", encoding="utf-8") as f:
json_object.update(json.load(f))

if not "paged_kv_cache_type" in json_object:
json_object["paged_kv_cache_type"] = "vllm"

return ModelArtifactConfig._from_json(json_object)
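
A small illustration of the backward-compatibility rule introduced above (the function name is hypothetical): build artifacts produced before this change carry no paged_kv_cache_type field, so the loader treats them as vLLM.

    import json

    def read_kv_type(config_json: str) -> str:
        """Return the paged KV cache type recorded in a build artifact config."""
        config = json.loads(config_json)
        # Older artifacts predate the field; default to vLLM, as the diff does.
        return config.get("paged_kv_cache_type", "vllm")

    # read_kv_type('{"model_type": "llama"}')                    -> "vllm"
    # read_kv_type('{"paged_kv_cache_type": "flash-decoding"}')  -> "flash-decoding"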
14 changes: 10 additions & 4 deletions serve/mlc_serve/model/model_common.py
@@ -31,6 +31,7 @@ def get_gpu_memory(gpu: int = 0) -> int:

def get_num_cache_blocks(
model,
block_size,
seq_lens,
num_layers,
num_kv_heads,
@@ -39,7 +40,7 @@ ):
):
used_memory_bytes = model.profile_memory_usage(seq_lens)
cache_block_size = CacheManager.get_cache_block_size(
num_layers, num_kv_heads, head_size
block_size, num_layers, num_kv_heads, head_size
)
total_vram = get_gpu_memory()
return int(
@@ -196,6 +197,7 @@ def prepare_inputs(
all_decode_block_tables,
sliding_window,
is_prefill,
num_decode_query_tokens=1,
):
block_tables = []
seq_lens = []
@@ -222,13 +224,17 @@ start_idx += prompt_len
start_idx += prompt_len

else:
input_ids.append(token_ids[-1])
seq_len = prompt_lens[i] + len(token_ids)
positions.append(seq_len - 1)
input_ids += token_ids[-num_decode_query_tokens:]

for i in range(num_decode_query_tokens):
positions.append(seq_len - (num_decode_query_tokens - i))

slot_mapping += all_slot_mappings[sequence_id][-num_decode_query_tokens:]

block_table = all_decode_block_tables[sequence_id]
max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table))
block_tables.append(block_table.get_blocks())
slot_mapping.append(all_slot_mappings[sequence_id][-1])

if sliding_window:
seq_lens.append(min(seq_len, sliding_window))
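
A worked, self-contained sketch (assumed data layout, hypothetical helper name) of what the new num_decode_query_tokens parameter changes for a single decode sequence: instead of feeding only the last token, the last N tokens, their positions, and their cache slots are gathered.

    from typing import List, Tuple

    def decode_inputs_for_sequence(
        prompt_len: int,
        token_ids: List[int],      # tokens generated so far for this sequence
        slot_mapping: List[int],   # one KV cache slot per token in the sequence
        num_decode_query_tokens: int = 1,
    ) -> Tuple[List[int], List[int], List[int]]:
        seq_len = prompt_len + len(token_ids)
        input_ids = token_ids[-num_decode_query_tokens:]
        # Positions of the queried tokens, counted from the start of the sequence.
        positions = [
            seq_len - (num_decode_query_tokens - i)
            for i in range(num_decode_query_tokens)
        ]
        slots = slot_mapping[-num_decode_query_tokens:]
        return input_ids, positions, slots

    # Prompt of 5 tokens, 3 generated tokens, querying the last 2 of them:
    # decode_inputs_for_sequence(5, [11, 12, 13], list(range(100, 108)), 2)
    # -> ([12, 13], [6, 7], [106, 107])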
8 changes: 4 additions & 4 deletions serve/mlc_serve/model/paged_cache_manager.py
@@ -119,12 +119,10 @@


class CacheManager:
block_size: int = 16

@staticmethod
def get_cache_block_size(num_layers, num_heads, head_size):
def get_cache_block_size(block_size, num_layers, num_heads, head_size):
# Taken from vllm/worker/cache_engine.py
key_cache_block = CacheManager.block_size * num_heads * head_size
key_cache_block = block_size * num_heads * head_size
value_cache_block = key_cache_block
total = num_layers * (key_cache_block + value_cache_block)
dtype_size = 2 # fp16
@@ -133,9 +131,11 @@ def get_cache_block_size(num_layers, num_heads, head_size):
def __init__(
self,
num_blocks: int,
block_size: int,
sliding_window: Optional[int] = None,
):
self.num_blocks = num_blocks
self.block_size = block_size
self.free_blocks = list(range(num_blocks))
self.kv_cache_info = KVCacheInfo(self.block_size)
self.token_counts = dict[SequenceId, int]()
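
To make the effect of the now-configurable block size concrete, here is the per-block memory formula from get_cache_block_size applied to illustrative (assumed) Llama-7B-like numbers: 32 layers, 32 KV heads, head size 128, fp16.

    def cache_block_size_bytes(block_size, num_layers, num_heads, head_size):
        # fp16 K and V tensors for every layer, as in get_cache_block_size above.
        key_cache_block = block_size * num_heads * head_size
        value_cache_block = key_cache_block
        dtype_size = 2  # fp16
        return dtype_size * num_layers * (key_cache_block + value_cache_block)

    # vLLM-style 16-token blocks vs. Flash-Decoding's 128-token blocks:
    # cache_block_size_bytes(16, 32, 32, 128)   -> 8_388_608   (~8 MiB per block)
    # cache_block_size_bytes(128, 32, 32, 128)  -> 67_108_864  (~64 MiB per block)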
89 changes: 65 additions & 24 deletions serve/mlc_serve/model/tvm_model.py
@@ -95,6 +95,7 @@ def _prepare_inputs(
all_decode_block_tables,
sliding_window,
is_prefill,
num_decode_query_tokens=1,
):
(
input_ids,
@@ -111,13 +112,17 @@
all_decode_block_tables,
sliding_window,
is_prefill,
num_decode_query_tokens,
)

if block_tables is not None:
block_tables = tvm.nd.from_dlpack(block_tables)
if indices_within_window is not None:
indices_within_window = tvm.nd.from_dlpack(indices_within_window)

if not is_prefill and num_decode_query_tokens > 1:
input_ids = torch.reshape(input_ids, (-1, num_decode_query_tokens))

return (
tvm.nd.from_dlpack(input_ids),
tvm.nd.from_dlpack(positions),
@@ -131,8 +136,10 @@
class Model:
def __init__(
self,
config,
dev,
config: ModelArtifactConfig,
dev: tvm.runtime.Device,
block_size: int,
copy_blocks_func_name: str,
):
self.mod, self.params, self.disco_session = get_tvm_model(config, dev)
self.dev = dev
@@ -152,7 +159,7 @@ def __init__(
self.torch_dev: str = "cuda"

if self.sliding_window:
self.block_sliding_window = self.sliding_window // CacheManager.block_size
self.block_sliding_window = self.sliding_window // block_size
else:
self.block_sliding_window = None

@@ -162,7 +169,7 @@ def __init__(
)
else:
self.copy_cache_blocks_func = tvm.get_global_func(
"tvm.contrib.vllm.copy_blocks"
copy_blocks_func_name,
)

self.cache_blocks = None
@@ -328,6 +335,9 @@ def generate(
all_token_ids = []
sequence_ids = []
prompt_lens = []
# TODO(masahi, yelite): Update this when a new request type for speculative decoding
# is implemented.
num_decode_query_tokens = 1
sampling_params = []
past_decode_tokens = []

@@ -383,6 +393,7 @@
cache.decode_block_tables,
self.sliding_window,
is_prefill,
num_decode_query_tokens,
)

input_shape = input_ids.shape
@@ -425,15 +436,26 @@
if self.disco_session:
block_tables = copy_to_worker_0(self.disco_session, block_tables)

out = self.mod["decode"](
input_ids,
positions,
seq_lens,
self.cache_blocks,
slot_mapping,
block_tables,
self.params,
)
if num_decode_query_tokens is not None and num_decode_query_tokens > 1:
out = self.mod["decode_multi_query"](
input_ids,
positions,
seq_lens,
self.cache_blocks,
slot_mapping,
block_tables,
self.params,
)
else:
out = self.mod["decode"](
input_ids,
positions,
seq_lens,
self.cache_blocks,
slot_mapping,
block_tables,
self.params,
)

if self.disco_session:
logits, _ = out.debug_get_from_remote(0)
@@ -461,6 +483,10 @@ def generate(
self.copy_cache_blocks_func(self.cache_blocks, block_mapping)
cache.pending_copy_from_to = []

if len(logits.shape) == 3:
# TODO(masahi, yelite): Proper logic for handling multi-query logits (speculative decoding).
return []

return sample_from_logits(
logits,
sequence_ids,
@@ -479,23 +505,39 @@ def init_tvm_model(
) -> Tuple[TextGenerator, CacheManager]:
dev = tvm.device("cuda", 0)

model = Model(model_artifact_config, dev)

if model_artifact_config.num_shards > 1:
model.disco_session.sync_worker_0()

num_kv_heads = (
model_artifact_config.num_key_value_heads // model_artifact_config.num_shards
)
head_size = (
model_artifact_config.hidden_size // model_artifact_config.num_attention_heads
)

if model_artifact_config.paged_kv_cache_type == "flash-decoding":
allocate_func_name = "tvm.contrib.flash_attn.allocate_kv_cache"
copy_blocks_func_name = "tvm.contrib.flash_attn.copy_blocks"
# This needs to match with the model definition in llama_batched_vllm.py
if head_size <= 64:
block_size = 256
elif head_size <= 128:
block_size = 128
else:
block_size = 64
else:
allocate_func_name = "tvm.contrib.vllm.allocate_kv_cache"
copy_blocks_func_name = "tvm.contrib.vllm.copy_blocks"
block_size = 16

model = Model(model_artifact_config, dev, block_size, copy_blocks_func_name)

if model_artifact_config.num_shards > 1:
model.disco_session.sync_worker_0()

if engine_config.max_num_batched_tokens > 0:
LOG.info("Running memory profiling.")
try:
num_blocks = get_num_cache_blocks(
model,
block_size,
[1] * engine_config.max_num_batched_tokens,
model_artifact_config.num_hidden_layers,
num_kv_heads,
@@ -509,7 +551,7 @@
else:
num_blocks = 500

num_cache_slots = num_blocks * CacheManager.block_size
num_cache_slots = num_blocks * block_size

if num_cache_slots <= engine_config.max_num_batched_tokens:
raise RuntimeError(
@@ -523,25 +565,24 @@
LOG.info(f"Using {num_blocks} cache blocks.")

if model.disco_session:
init_cache_func = model.disco_session.get_global_func(
"tvm.contrib.vllm.allocate_kv_cache"
)
init_cache_func = model.disco_session.get_global_func(allocate_func_name)
else:
init_cache_func = tvm.get_global_func("tvm.contrib.vllm.allocate_kv_cache")
init_cache_func = tvm.get_global_func(allocate_func_name)

try:
model.cache_blocks = init_cache_func(
head_size,
model_artifact_config.num_hidden_layers,
num_kv_heads,
CacheManager.block_size,
block_size,
num_blocks,
)
except tvm.error.InternalError:
raise RuntimeError(f"Failed to allocate {num_blocks} cache blocks.")

cache_manager = CacheManager(
num_blocks,
block_size,
model_artifact_config.sliding_window,
)
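
A condensed sketch of the backend-specific setup done in init_tvm_model above (the helper name is invented): Flash-Decoding picks its KV cache block size from the head dimension and uses the flash_attn allocate/copy packed functions, while vLLM keeps the fixed block size of 16.

    def select_kv_cache_backend(paged_kv_cache_type: str, head_size: int):
        if paged_kv_cache_type == "flash-decoding":
            allocate_func = "tvm.contrib.flash_attn.allocate_kv_cache"
            copy_blocks_func = "tvm.contrib.flash_attn.copy_blocks"
            # Block size must match the model definition in llama_batched_vllm.py.
            if head_size <= 64:
                block_size = 256
            elif head_size <= 128:
                block_size = 128
            else:
                block_size = 64
        else:
            allocate_func = "tvm.contrib.vllm.allocate_kv_cache"
            copy_blocks_func = "tvm.contrib.vllm.copy_blocks"
            block_size = 16
        return allocate_func, copy_blocks_func, block_size

    # select_kv_cache_backend("flash-decoding", 128) -> (..., ..., 128)
    # select_kv_cache_backend("vllm", 128)           -> (..., ..., 16)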

