Limit number of sequences #220

Merged (11 commits, Feb 23, 2024)
3 changes: 3 additions & 0 deletions serve/mlc_serve/engine/base.py
@@ -31,11 +31,14 @@ class MLCServeEngineConfig:
# TODO(@sunggg): figure out better defaults
use_staging_engine: bool = True
max_num_batched_tokens: int = 4096
max_num_seq: int = 256
max_num_seq_per_request: Optional[int] = None # defaults to `max_num_seq // 4`
min_decode_steps: int = 32
max_decode_steps: int = 48
init_timeout: int = 120
model_type: str = "tvm" # "tvm", "torch"
num_shards: Optional[int] = None # Needs to be specified if model_type is "torch"
gpu_memory_utilization: float = 0.9

@classmethod
def _from_json(config_cls, json_obj: Dict[Any, Any]):
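For context, a minimal sketch of what the new knobs control. The real `MLCServeEngineConfig` lives in `serve/mlc_serve/engine/base.py` and has more fields; the stand-in below mirrors only the additions in this PR and assumes dataclass-style construction:

```python
from dataclasses import dataclass
from typing import Optional

# Illustrative stand-in for the new fields only (not the real config class).
@dataclass
class EngineLimits:
    max_num_batched_tokens: int = 4096
    max_num_seq: int = 256                         # cap on generation sequences in a batch
    max_num_seq_per_request: Optional[int] = None  # engine derives max_num_seq // 4 when None
    gpu_memory_utilization: float = 0.9            # fraction of GPU memory the server may use

limits = EngineLimits(max_num_seq=128, gpu_memory_utilization=0.85)
```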
14 changes: 14 additions & 0 deletions serve/mlc_serve/engine/engine_common.py
@@ -405,6 +405,8 @@ class EngineBase:
model_artifact_config: ModelArtifactConfig
max_context_length: int
max_num_batched_tokens: int
max_num_seq: int
max_num_seq_per_request: int
max_decode_steps: int
min_decode_steps: int
kv_cache_size: int
@@ -426,6 +428,10 @@ def __init__(self, model_module: ModelModule):
), "max_context_length must not be zero"
self.max_context_length = self.model_artifact_config.max_context_length
self.max_num_batched_tokens = model_module.engine_config.max_num_batched_tokens
self.max_num_seq = model_module.engine_config.max_num_seq
self.max_num_seq_per_request = model_module.engine_config.max_num_seq_per_request
if self.max_num_seq_per_request is None:
self.max_num_seq_per_request = self.max_num_seq // 4
self.max_decode_steps = min(
self.cache_manager.get_kv_cache_size(),
model_module.engine_config.max_decode_steps,
@@ -592,6 +598,14 @@ def try_grow_batch(self, num_new_batched_tokens) -> Optional[int]:
)
return None

current_num_seq = sum(len(s.generation_sequences) for s in self.current_batch.values())
if current_num_seq + len(state.generation_sequences) > self.max_num_seq:
LOG.debug(
"Stop growing the batch due to max number of sequences.",
)
return None


self.queue.popleft()
self.cache_manager.allocate(state.request_id, num_tokens, state.num_sequences)
self.current_batch[state.request_id] = state
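A standalone sketch of the two additions here: the default derivation of `max_num_seq_per_request` and the admission check `try_grow_batch` now performs before pulling a request off the queue. `GenState` below is a hypothetical stand-in for the engine's per-request state; only the `generation_sequences` attribute is taken from the diff:

```python
from dataclasses import dataclass, field

# Hypothetical stand-in for a request's state in the current batch.
@dataclass
class GenState:
    generation_sequences: list = field(default_factory=list)

def resolve_max_num_seq_per_request(max_num_seq: int, configured=None) -> int:
    # Config leaves this as None by default; EngineBase falls back to a quarter of the cap.
    return configured if configured is not None else max_num_seq // 4

def batch_has_room(current_batch: dict, new_num_seqs: int, max_num_seq: int) -> bool:
    # Mirrors the check above: count sequences already in flight, then refuse
    # to grow the batch past the batch-wide sequence cap.
    current_num_seq = sum(len(s.generation_sequences) for s in current_batch.values())
    return current_num_seq + new_num_seqs <= max_num_seq
```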
6 changes: 6 additions & 0 deletions serve/mlc_serve/engine/staging_engine_worker.py
@@ -113,6 +113,12 @@ def add(self, request_states: list[RequestState]):
"The prompt is too long for the given set of engine"
" parameters."
)
elif state.num_sequences > self.max_num_seq_per_request:
self.cancelled_requests.append(state)
state.validation_err = ValidationError(
f"The number of sequences ({state.num_sequences}) is greater "
f"than the maximum allowed value ({self.max_num_seq_per_request})"
)
else:
valid_states.append(state)

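The effect of this check in isolation: a request asking for more parallel sequences than `max_num_seq_per_request` is cancelled with a `ValidationError` instead of being scheduled. A simplified sketch with illustrative numbers (the worker plumbing is omitted):

```python
# With the defaults, max_num_seq_per_request resolves to 256 // 4 == 64,
# so a request sampling 100 sequences is rejected up front.
max_num_seq_per_request = 256 // 4
requested_sequences = 100

if requested_sequences > max_num_seq_per_request:
    error_message = (
        f"The number of sequences ({requested_sequences}) is greater "
        f"than the maximum allowed value ({max_num_seq_per_request})"
    )
    # In the worker, this message is wrapped in ValidationError and the request
    # goes to cancelled_requests rather than valid_states.
    print(error_message)
```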
2 changes: 1 addition & 1 deletion serve/mlc_serve/model/model_common.py
@@ -35,7 +35,7 @@ def get_num_cache_blocks(
num_layers,
num_kv_heads,
head_size,
gpu_memory_utilization=0.9, # the default used by vllm
gpu_memory_utilization,
):
cache_block_size = CacheManager.get_cache_block_size(
block_size, num_layers, num_kv_heads, head_size
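`gpu_memory_utilization` previously defaulted to 0.9 inside `get_num_cache_blocks`; it is now threaded in from the engine config. The body of the function is not part of this diff, but servers in this style typically size the KV cache with vLLM-like accounting; the sketch below is an assumption meant only to illustrate what the parameter controls, not the actual implementation:

```python
# Assumed, simplified accounting: the KV cache gets whatever is left of the
# utilization budget after the memory measured during profiling.
def estimate_num_cache_blocks(
    used_memory_bytes: int,
    cache_block_size_bytes: int,
    total_gpu_memory_bytes: int,
    gpu_memory_utilization: float,
) -> int:
    budget = total_gpu_memory_bytes * gpu_memory_utilization
    free_for_kv_cache = budget - used_memory_bytes
    return max(int(free_for_kv_cache // cache_block_size_bytes), 0)
```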
11 changes: 10 additions & 1 deletion serve/mlc_serve/model/torch_model.py
@@ -168,6 +168,8 @@ def profile_and_init_cache(
hf_config,
num_shards,
max_num_batched_tokens,
max_num_seq,
gpu_memory_utilization,
):
num_kv_heads = hf_config.num_key_value_heads // num_shards
num_hidden_layers = hf_config.num_hidden_layers
@@ -177,7 +179,9 @@

if max_num_batched_tokens > 0:
LOG.info("Running memory profiling.")
seq_lens = [1] * max_num_batched_tokens
seq_len = max_num_batched_tokens // max_num_seq
seq_lens = [seq_len] * max_num_seq
seq_lens[-1] += max_num_batched_tokens % max_num_seq
used_memory_bytes = profile_memory_usage(
pt_model, seq_lens, num_hidden_layers, hf_config.vocab_size
)
@@ -187,6 +191,7 @@
hf_config.num_hidden_layers,
num_kv_heads,
head_size,
gpu_memory_utilization,
)
else:
num_blocks = 500
@@ -423,6 +428,8 @@ def exposed_init_model(
hf_config,
num_shards,
engine_config.max_num_batched_tokens,
engine_config.max_num_seq,
engine_config.gpu_memory_utilization,
)

return num_blocks
@@ -593,6 +600,8 @@ def __init__(
hf_config,
1,
engine_config.max_num_batched_tokens,
engine_config.max_num_seq,
engine_config.gpu_memory_utilization,
)
self.model_rpc = None

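The profiling workload changes shape: instead of `max_num_batched_tokens` sequences of length 1, the profiler now runs `max_num_seq` sequences whose lengths sum to the same token budget. With the defaults this works out as follows (a worked example, not code from the PR):

```python
# Defaults: 4096 batched tokens spread over 256 sequences.
max_num_batched_tokens = 4096
max_num_seq = 256

seq_len = max_num_batched_tokens // max_num_seq       # 16 tokens per sequence
seq_lens = [seq_len] * max_num_seq                    # 256 sequences of length 16
seq_lens[-1] += max_num_batched_tokens % max_num_seq  # remainder is 0 here

assert sum(seq_lens) == max_num_batched_tokens
```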
8 changes: 7 additions & 1 deletion serve/mlc_serve/model/tvm_model.py
@@ -588,14 +588,20 @@ def init_tvm_model(
if engine_config.max_num_batched_tokens > 0:
LOG.info("Running memory profiling.")
try:
seq_lens = [1] * engine_config.max_num_batched_tokens
max_num_seq = engine_config.max_num_seq
max_num_batched_tokens = engine_config.max_num_batched_tokens
seq_len = max_num_batched_tokens // max_num_seq
seq_lens = [seq_len] * max_num_seq
seq_lens[-1] += max_num_batched_tokens % max_num_seq

used_memory_bytes = model.profile_memory_usage(seq_lens)
num_blocks = get_num_cache_blocks(
used_memory_bytes,
block_size,
model_artifact_config.num_hidden_layers,
num_kv_heads,
head_size,
engine_config.gpu_memory_utilization,
)
except tvm.error.InternalError:
raise RuntimeError(
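The TVM path applies the same split; the last sequence absorbs the remainder when the token budget does not divide evenly by `max_num_seq`. A small worked example with illustrative (non-default) values:

```python
# 4000 tokens over 256 sequences does not divide evenly.
max_num_batched_tokens, max_num_seq = 4000, 256

seq_len = max_num_batched_tokens // max_num_seq       # 15
seq_lens = [seq_len] * max_num_seq
seq_lens[-1] += max_num_batched_tokens % max_num_seq  # last sequence gets 15 + 160 = 175

assert sum(seq_lens) == max_num_batched_tokens
```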
4 changes: 4 additions & 0 deletions serve/mlc_serve/utils.py
@@ -30,8 +30,10 @@ def get_default_mlc_serve_argparser(description="", allow_override=False):
parser.add_argument("--use-sync-engine", action="store_true")
parser.add_argument("--num-sequences-to-sample", type=int, default=1)
parser.add_argument("--max-num-batched-tokens", type=int, default=4096)
parser.add_argument("--max-num-seq", type=int, default=256)
parser.add_argument("--min-decode-steps", type=int, default=32)
parser.add_argument("--max-decode-steps", type=int, default=56)
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
parser.add_argument("--debug-logging", action="store_true")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--num-shards", type=int, default=1) # Needed for PT models
@@ -73,10 +75,12 @@ def create_mlc_engine(args: argparse.Namespace, start_engine=True) -> InferenceE
{
"use_staging_engine": args.use_staging_engine,
"max_num_batched_tokens": args.max_num_batched_tokens,
"max_num_seq": args.max_num_seq,
"min_decode_steps": args.min_decode_steps,
"max_decode_steps": args.max_decode_steps,
"model_type": model_type,
"num_shards": num_shards,
"gpu_memory_utilization": args.gpu_memory_utilization,
}
)

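A short sketch of how the two new flags travel from the argument parser into the engine-config dictionary, following the pattern already used by `create_mlc_engine`. The parser below is a stripped-down stand-in, not the real `get_default_mlc_serve_argparser`:

```python
import argparse

# Hypothetical minimal parser exposing just the flags added in this PR.
parser = argparse.ArgumentParser()
parser.add_argument("--max-num-seq", type=int, default=256)
parser.add_argument("--gpu-memory-utilization", type=float, default=0.9)
args = parser.parse_args(["--max-num-seq", "128", "--gpu-memory-utilization", "0.85"])

# These entries join the dictionary passed to MLCServeEngineConfig in create_mlc_engine.
overrides = {
    "max_num_seq": args.max_num_seq,                        # 128
    "gpu_memory_utilization": args.gpu_memory_utilization,  # 0.85
}
```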