Enable running PyTorch models (#207)
* refactor to separate TVM specific bits from paged_cache_model

* fix

* Remove engine config change for now

* make mypy happy with TextGenerator impl by Model

* stub

* wip

* wip

* wip

* PT model memory profiling works

* get rid of vllm prepare_inputs

* wip

* model runs but nan output

* mypy improvement

* runs e2e but the result is garbage

* working

* minor

* do sampling by mlc function

* merge fix

* wip parallel sampling

* fix test

* wip

* fix

* wip

* wip

* wip

* attach cache_blocks to model

* change get_num_cache_blocks signature

* wip

* wip

* wip

* refactor

* update for qwen

* merge fix

* clean

* KV cache refactor to decouple cache blocks and metadata about them

* update for KV refactor

* updated for the latest vllm

* qwen and phi supported

* Make num_shards configurable via engine config

* unify Model and ModelRpcClient classes

* support PT model in server

* properly allocate port

* refactor engine creation

* fix sync point

* do not create executor at each step

* remove dup obtain calls

* fix

* use sample_from_logits

* enable TCP NoDelay option to fix slow socket recv issue

* Replace TCP with Unix domain socket

* clean and add note on RPC overhead

* clean

* RPC process join works

* fix mypy

* merge fix

* wip test fix

* fix

* Properly verify sampling params in api handler

* Create model artifact config before module initialization

* fix engine start

* fix

* black

* properly handle import failure

* add tiktoken dep

* revert logprob change

* restored tokenizer.is_fast assert but commented out

* fix vocab size

* properly account for logits storage in memory profiling

* merge fix

* validate num_shards in engine creation

* replace print with structlog

* add peak memory log for tvm as well

* add tokenizer.is_fast warning on creation
masahi authored Feb 22, 2024
1 parent abe93a1 commit a377c3c
Showing 13 changed files with 850 additions and 64 deletions.
7 changes: 7 additions & 0 deletions serve/mlc_serve/engine/base.py
@@ -34,6 +34,8 @@ class MLCServeEngineConfig:
min_decode_steps: int = 32
max_decode_steps: int = 48
init_timeout: int = 120
model_type: str = "tvm" # "tvm", "torch"
num_shards: Optional[int] = None  # Needs to be specified if model_type is "torch"

@classmethod
def _from_json(config_cls, json_obj: Dict[Any, Any]):
@@ -57,6 +59,11 @@ def get_engine_config(dict_config):
assert (engine_config.min_decode_steps > 0) and (engine_config.max_decode_steps > 0)
assert engine_config.max_decode_steps > engine_config.min_decode_steps

if engine_config.model_type == "torch":
assert (
engine_config.num_shards is not None
), "num_shards in MLCServeEngineConfig needs to be provided for PT models."

return engine_config


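A minimal usage sketch for the two new config fields, assuming get_engine_config is imported from the module shown above (the numeric values are illustrative, not taken from this commit):

    from mlc_serve.engine.base import get_engine_config  # import path assumed from the file above

    # "tvm" stays the default model_type; "torch" additionally requires num_shards up front.
    engine_config = get_engine_config({
        "min_decode_steps": 32,
        "max_decode_steps": 48,
        "model_type": "torch",
        "num_shards": 2,
    })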
2 changes: 0 additions & 2 deletions serve/mlc_serve/engine/engine_common.py
@@ -134,8 +134,6 @@ def detokenize_incrementally(
prefix_begin_offset = generation_sequence.prefix_begin_offset
prefix_end_offset = generation_sequence.prefix_end_offset

assert tokenizer.is_fast

prefix_text = tokenizer.convert_tokens_to_string(
output_tokens[prefix_begin_offset:prefix_end_offset]
)
2 changes: 1 addition & 1 deletion serve/mlc_serve/engine/model_module.py
@@ -104,7 +104,7 @@ def get_cache(self) -> KVCache:
The returned value should be passed to Executor.generate_text.
"""

def allocate(self, request_id: RequestId, num_tokens: int, num_sequnces: int):
def allocate(self, request_id: RequestId, num_tokens: int, num_sequences: int):
"""
Allocate cache space for request, raise error if there is no space.
"""
3 changes: 2 additions & 1 deletion serve/mlc_serve/engine/staging_engine_worker.py
@@ -358,7 +358,7 @@ def run_generation_loop_worker(

try:
model_module = model_module_loader(**model_module_loader_kwargs)
LOG.info("Model is initalized.")
LOG.info("Model is initialized.")
worker = GenerationLoopWorker(model_module=model_module)
except:
LOG.exception("An error raised in model initialization.")
@@ -370,6 +370,7 @@ def handle_command():
while True:
cmd = command_queue.get()
if isinstance(cmd, ShutdownCommand):
del worker.text_generator
break
elif isinstance(cmd, AddRequestsCommand):
worker.add(cmd.request_states)
17 changes: 17 additions & 0 deletions serve/mlc_serve/model/base.py
@@ -1,5 +1,8 @@
from transformers import AutoConfig

from dataclasses import dataclass
from typing import Optional
from pathlib import Path
import os
import json
import inspect
@@ -57,3 +60,17 @@ def get_model_artifact_config(model_artifact_path):
json_object["paged_kv_cache_type"] = "vllm"

return ModelArtifactConfig._from_json(json_object)


def get_hf_config(model_path: Path) -> AutoConfig:
hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

if not hasattr(hf_config, "num_key_value_heads") and hasattr(
hf_config, "num_attention_heads"
):
hf_config.num_key_value_heads = hf_config.num_attention_heads

if not hasattr(hf_config, "sliding_window"):
hf_config.sliding_window = None

return hf_config
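A short sketch of how this helper might be consumed when initializing a PyTorch model; the model path is a made-up example, but the attribute guarantees are the ones get_hf_config establishes above:

    from pathlib import Path

    hf_config = get_hf_config(Path("/models/my-model"))    # hypothetical local checkout
    num_kv_heads = hf_config.num_key_value_heads            # falls back to num_attention_heads
    sliding_window = hf_config.sliding_window               # None when the model has no sliding window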
104 changes: 67 additions & 37 deletions serve/mlc_serve/model/model_common.py
@@ -30,15 +30,13 @@ def get_gpu_memory(gpu: int = 0) -> int:


def get_num_cache_blocks(
model,
used_memory_bytes,
block_size,
seq_lens,
num_layers,
num_kv_heads,
head_size,
gpu_memory_utilization=0.9, # the default used by vllm
):
used_memory_bytes = model.profile_memory_usage(seq_lens)
cache_block_size = CacheManager.get_cache_block_size(
block_size, num_layers, num_kv_heads, head_size
)
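Under the new signature, memory profiling happens at the call site and only the measured byte count is passed in; a rough sketch of a caller (hypothetical; model, seq_lens, and the layer/head counts are assumed to exist there):

    used_memory_bytes = model.profile_memory_usage(seq_lens)  # profiling is now the caller's job
    num_blocks = get_num_cache_blocks(
        used_memory_bytes,
        block_size,
        num_layers,
        num_kv_heads,
        head_size,
    )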
Expand Down Expand Up @@ -85,22 +83,18 @@ def sample_from_logits(
requests: Sequence[RequestType],
sampling_state: SamplingState,
vocab_size: int,
copy_stream: torch.cuda.Stream,
torch_dtype: torch.dtype,
torch_dev: str,
past_decode_tokens: List[List[int]],
prompt_masks: List[torch.Tensor],
) -> List[TextGenerationResult]:
batch_size = logits.shape[0]
assert batch_size == len(requests)

# Convert to torch tensors if logits are in tvm ndarray
if isinstance(logits, tvm.nd.NDArray):
logits = torch.from_dlpack(logits)

# synchronization point for sampling tensors
# wait until all the tensors are loaded on GPU
torch.cuda.current_stream().wait_stream(copy_stream)

# Logit processing for constraint sampling e.g., JSON Mode
for i, (sequence_id, request) in enumerate(zip(sequence_ids, requests)):
if request.sampling_params.logits_processor is not None:
@@ -140,6 +134,7 @@ def sample_from_logits(
" or element < 0"
)
logits = torch.from_dlpack(logits)

for i in range(batch_size):
sequence_id = sequence_ids[i]
logits_per_token = logits[i]
@@ -149,16 +144,14 @@
# NOTE: Rerun the preparation for simplicity.
# Assume this code path is taken rarely and the recomputation overhead is
# marginal.
with torch.cuda.stream(copy_stream):
new_sampling_state = SamplingState.from_sampling_params(
[sampling_param],
[past_decode_tokens_per_request],
[prompt_mask],
torch_dtype,
torch_dev,
vocab_size,
)
torch.cuda.current_stream().wait_stream(copy_stream)
new_sampling_state = SamplingState.from_sampling_params(
[sampling_param],
[past_decode_tokens_per_request],
[prompt_mask],
torch_dtype,
torch_dev,
vocab_size,
)
maybe_sampling_output: Optional[SamplingOutput] = sample(
torch.unsqueeze(logits_per_token, 0),
new_sampling_state,
@@ -169,6 +162,7 @@
logprob_info = maybe_sampling_output.logprob_infos[0]
# Valid sample
request = requests[i]

if maybe_sampling_output is not None:
outputs.extend(
prepare_textgen_result(
@@ -200,24 +194,39 @@ def prepare_inputs(
all_decode_block_tables,
sliding_window,
is_prefill,
block_size,
num_decode_query_tokens=1,
for_vllm=False,
):
if for_vllm:
torch_int_dtype = torch.long
else:
torch_int_dtype = torch.int

block_tables = []
seq_lens = []
input_ids = []
slot_mapping = []
positions = []
max_num_blocks_per_seq = 0
indices_within_window = []
start_idx = 0
max_prompt_len = -1
max_context_len = -1

for i, (sequence_id, token_ids) in enumerate(zip(sequence_ids, all_token_ids)):
if is_prefill:
input_ids += token_ids
prompt_len = len(token_ids)
seq_lens.append(prompt_len)
positions += range(prompt_len)
slot_mapping += all_slot_mappings[sequence_id]
max_prompt_len = max(max_prompt_len, prompt_len)

if for_vllm:
input_ids.append(token_ids)
positions.append(list(range(prompt_len)))
slot_mapping.append(all_slot_mappings[sequence_id])
else:
input_ids += token_ids
positions += range(prompt_len)
slot_mapping += all_slot_mappings[sequence_id]

if sliding_window:
indices_within_window += range(
@@ -228,44 +237,65 @@

else:
seq_len = prompt_lens[i] + len(token_ids)
input_ids += token_ids[-num_decode_query_tokens:]

for i in range(num_decode_query_tokens):
positions.append(seq_len - (num_decode_query_tokens - i))
if for_vllm:
assert num_decode_query_tokens == 1
input_ids.append([token_ids[-1]])
positions.append([seq_len - 1])
slot_mapping.append([all_slot_mappings[sequence_id][-1]])
else:
input_ids += token_ids[-num_decode_query_tokens:]

slot_mapping += all_slot_mappings[sequence_id][-num_decode_query_tokens:]
for i in range(num_decode_query_tokens):
positions.append(seq_len - (num_decode_query_tokens - i))

slot_mapping += all_slot_mappings[sequence_id][-num_decode_query_tokens:]

block_table = all_decode_block_tables[sequence_id]
max_num_blocks_per_seq = max(max_num_blocks_per_seq, len(block_table))
block_tables.append(block_table.get_blocks())

if sliding_window:
seq_lens.append(min(seq_len, sliding_window))
else:
seq_lens.append(seq_len)

max_context_len = max(max_context_len, seq_lens[-1])

def _do_pad(
x: List[List[int]],
max_len: int,
pad_val: int,
) -> List[List[int]]:
def _pad_to_max(x: List[int], max_len: int, pad_val: int) -> List[int]:
assert len(x) <= max_len
return x + [pad_val] * (max_len - len(x))

return [_pad_to_max(x_i, max_len, pad_val) for x_i in x]

if for_vllm and is_prefill:
input_ids = _do_pad(input_ids, max_prompt_len, 0)
positions = _do_pad(positions, max_prompt_len, 0)
slot_mapping = _do_pad(slot_mapping, max_prompt_len, -1)

def to_torch(arr, torch_dtype):
return torch.tensor(arr, dtype=torch_dtype, device="cuda")

input_ids = to_torch(input_ids, torch.int)
positions = to_torch(positions, torch.int)
input_ids = to_torch(input_ids, torch_int_dtype)
positions = to_torch(positions, torch_int_dtype)
seq_lens = to_torch(seq_lens, torch.int)
slot_mapping = to_torch(slot_mapping, torch.int)
slot_mapping = to_torch(slot_mapping, torch_int_dtype)

if is_prefill and sliding_window:
indices_within_window = to_torch(indices_within_window, torch.int)
else:
indices_within_window = None

if not is_prefill:
max_block_table_len = (
max_context_len + block_size - 1
) // block_size

def _pad_to_max(x: List[int], max_len: int) -> List[int]:
return x + [0] * (max_len - len(x))

padded_block_tables = [
_pad_to_max(block_table, max_num_blocks_per_seq)
for block_table in block_tables
]
padded_block_tables = _do_pad(block_tables, max_block_table_len, 0)
block_tables = to_torch(padded_block_tables, torch.int)
else:
block_tables = None
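For clarity, the nested _do_pad helper right-pads every per-sequence list to the same length; on a toy input it behaves like this (illustrative only):

    _do_pad([[1, 2], [3]], max_len=4, pad_val=0)
    # -> [[1, 2, 0, 0], [3, 0, 0, 0]]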
14 changes: 11 additions & 3 deletions serve/mlc_serve/model/paged_cache_model.py
@@ -5,6 +5,7 @@
from .base import ModelArtifactConfig
from .paged_cache_manager import CacheManager
from .tokenizer import HfTokenizerModule, ConversationTemplate, Tokenizer
from .torch_model import init_torch_model
from .tvm_model import init_tvm_model

from ..engine import MLCServeEngineConfig
@@ -81,15 +82,22 @@ def __init__(
engine_config: MLCServeEngineConfig,
model_artifact_config: ModelArtifactConfig
):
# TODO(masahi): Make the model type configurable.
model, cache_manager = init_tvm_model(model_artifact_config, engine_config)
if engine_config.model_type == "tvm":
model, cache_manager = init_tvm_model(model_artifact_config, engine_config)
tokenizer_module = HfTokenizerModule(model_artifact_path.joinpath("model"))
elif engine_config.model_type == "torch":
model, cache_manager = init_torch_model(
model_artifact_path, engine_config
)
tokenizer_module = HfTokenizerModule(model_artifact_path)
else:
raise RuntimeError(f"Unknown model type {engine_config.model_type}")

self.engine_config = engine_config
self.model_artifact_config = model_artifact_config
self.text_generator = PagedCacheModelTextGenerator(model)
self.cache_manager = cache_manager

tokenizer_module = HfTokenizerModule(model_artifact_path)
self.tokenizer = tokenizer_module.tokenizer
self.conversation_template = tokenizer_module.conversation_template

12 changes: 9 additions & 3 deletions serve/mlc_serve/model/sampler.py
@@ -479,15 +479,15 @@ def adjust_logits(logits: torch.Tensor, sampling_state: SamplingState, vocab_siz

@dataclass
class SamplingOutput:
next_tokens: list[int]
logprob_infos: list[Optional[RawLogprobsInfo]]
next_tokens: np.ndarray
logprob_infos: List[Optional[RawLogprobsInfo]]


def sample(
logits: torch.Tensor,
sampling_state: SamplingState,
check_safety: bool = False,
) -> SamplingOutput:
) -> Optional[SamplingOutput]:
def _is_safe_to_sample(prob_like):
return (
torch.sum(torch.isnan(prob_like) | torch.isinf(prob_like) | (prob_like < 0))
@@ -504,21 +504,26 @@ def _is_safe_to_sample(prob_like):
)

next_tokens = np.empty((batch_size,), dtype=np.int64)

if sampling_state.has_greedy:
res_greedy = torch.argmax(logits[mask_greedy_t], -1)
np_mask_greedy = mask_greedy_t.cpu().numpy()
next_tokens[np_mask_greedy] = res_greedy.cpu().numpy()

probs_random = None

if sampling_state.has_random:
probs_random = torch.softmax(logits[mask_random_t], dim=-1)

if check_safety and not _is_safe_to_sample(probs_random):
return None

res_random = torch.multinomial(probs_random, 1, True)[:, 0]
np_mask_random = mask_random_t.cpu().numpy()
next_tokens[np_mask_random] = res_random.cpu().numpy()

logprob_infos: List[Optional[RawLogprobsInfo]] = [None] * batch_size

if sampling_state.has_logprob:
# If everything is random sampling, save one extra softmax
if not sampling_state.has_greedy:
@@ -535,6 +540,7 @@ def _is_safe_to_sample(prob_like):
mask = sampling_state.sampling_tensors.mask_top_logprob
top_tokens = all_top_tokens[mask]
top_logprobs = all_top_logprobs[mask]

for idx, batch_idx in enumerate(sampling_state.logprob_batch_indices):
next_token = next_tokens[batch_idx]
assert sampling_state.sampling_params[batch_idx].logprobs
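Because sample can now return None when check_safety detects unsafe probabilities, callers are expected to branch on the result, roughly as in this sketch (hypothetical usage mirroring sample_from_logits above):

    maybe_output = sample(logits, sampling_state, check_safety=True)
    if maybe_output is None:
        # NaN/Inf/negative probabilities detected; fall back to per-request resampling.
        ...
    else:
        next_tokens = maybe_output.next_tokens        # np.ndarray, one token id per sequence
        logprob_infos = maybe_output.logprob_infos    # List[Optional[RawLogprobsInfo]]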