feat(misc): Profiler support #611

Open · wants to merge 1 commit into base: main
15 changes: 15 additions & 0 deletions lightllm/server/api_cli.py
@@ -219,4 +219,19 @@ def make_argument_parser() -> argparse.ArgumentParser:
help="""Maximum sequence length that can be captured by the cuda graph for decodign stage.
The default value is 8192. It will turn into eagar mode if encounters a larger value. """,
)
parser.add_argument(
"--profiler",
type=str,
choices=["torch_profile", "nvtx"],
default=None,
help="""Enable profiler support.
This will expose '/profiler_start' and '/profiler_stop' API,
below profiling features will only been enabled in this range.
Options:
'torch_profile' will setup torch.profiler.profile(), traces file will been saved to './trace',
or set by 'LIGHTLLM_TRACE_DIR' env;
'nvtx' will add NVTX marks for external profiler like NVIDIA Nsight System
(you should setup it by youself).
A NVTX named 'LIGHTLLM_PROFILE' will been added within the profiling range.""",
)
return parser
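
To exercise the new flag end-to-end, a minimal client-side sketch (not part of this PR) that brackets a workload with the new endpoints could look like the following; it assumes the API server was launched with `--profiler torch_profile` and is reachable at `http://localhost:8000` (adjust to your `--host`/`--port`):

```python
# Minimal sketch: wrap a workload in a profiling window via the new endpoints.
# Assumes the server was started with --profiler=... and listens on localhost:8000.
import requests

BASE = "http://localhost:8000"

requests.get(f"{BASE}/profiler_start").raise_for_status()  # HTTP 500 if --profiler was not set

# ... send the requests you want to profile here, e.g. your normal inference traffic ...

requests.get(f"{BASE}/profiler_stop").raise_for_status()
# In torch_profile mode the trace is written to ./trace (or $LIGHTLLM_TRACE_DIR).
```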
18 changes: 18 additions & 0 deletions lightllm/server/api_server.py
@@ -350,6 +350,24 @@ async def kv_move_status(websocket: WebSocket):
return


@app.get("/profiler_start")
async def profiler_start() -> Response:
if g_objs.args.profiler:
g_objs.httpserver_manager.profiler_msg("start")
return {"status": "ok"}
else:
return JSONResponse({"message": "Profiling support not enabled"}, status_code=500)


@app.get("/profiler_stop")
async def profiler_stop() -> Response:
if g_objs.args.profiler:
g_objs.httpserver_manager.profiler_msg("stop")
return {"status": "ok"}
else:
return JSONResponse({"message": "Profiling support not enabled"}, status_code=500)


@app.on_event("shutdown")
async def shutdown():
logger.info("Received signal to shutdown. Performing graceful shutdown...")
6 changes: 5 additions & 1 deletion lightllm/server/httpserver/manager.py
@@ -13,7 +13,7 @@
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
from typing import Union, List, Tuple, Dict
from ..tokenizer import get_tokenizer
from ..io_struct import BatchStrOut, AbortReq, FinishStatus
from ..io_struct import BatchStrOut, AbortReq, ProfilerReq, FinishStatus
from ..pd_io_struct import NodeRole
from ..embed_cache.utils import get_shm_name_data, create_shm
from ..req_id_generator import convert_sub_id_to_group_id
@@ -438,6 +438,10 @@ async def timer_to_pd_master(self):
await asyncio.sleep(10)
logger.info("reconnection to pd_master")

def profiler_msg(self, msg):
profiler_req = ProfilerReq(msg)
self.send_to_router.send_pyobj(profiler_req)


class ReqStatus:
def __init__(self, req_id, multimodal_params) -> None:
5 changes: 5 additions & 0 deletions lightllm/server/io_struct.py
@@ -344,3 +344,8 @@ def __init__(self):
class AbortReq:
def __init__(self, group_req_id):
self.group_req_id = group_req_id


class ProfilerReq:
def __init__(self, msg):
self.msg = msg
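
ProfilerReq travels over the same ZMQ channel the HTTP server already uses for AbortReq: it is pickled with send_pyobj and dispatched by type in the router's loop_for_netio_req. A self-contained sketch of that transport pattern is below; the port number and the PUSH/PULL socket pair are illustrative stand-ins, not the server's actual wiring.

```python
# Standalone sketch of the control-message pattern: pickle a small request object
# over ZMQ and dispatch on its type at the receiver. Port and socket types are
# illustrative; the server's actual sockets may be wired differently.
import zmq


class ProfilerReq:
    def __init__(self, msg):
        self.msg = msg


ctx = zmq.Context()
receiver = ctx.socket(zmq.PULL)   # stands in for the router's receive socket
receiver.bind("tcp://127.0.0.1:12345")
sender = ctx.socket(zmq.PUSH)     # stands in for the httpserver's send_to_router
sender.connect("tcp://127.0.0.1:12345")

sender.send_pyobj(ProfilerReq("start"))
recv_req = receiver.recv_pyobj()
if isinstance(recv_req, ProfilerReq):
    print(f"router would call profiler_ops({recv_req.msg!r})")
```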
29 changes: 28 additions & 1 deletion lightllm/server/router/manager.py
@@ -20,11 +20,12 @@
from lightllm.utils.infer_utils import calculate_time
from .dynamic_prompt.shared_arr import SharedInt
from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
from ..io_struct import BatchTokenIdOut, AbortReq, ReqRunStatus, FinishStatus, ReqDetokenizationState
from ..io_struct import BatchTokenIdOut, AbortReq, ProfilerReq, ReqRunStatus, FinishStatus, ReqDetokenizationState
from .stats import Stats
from .pause_strategy import Fcfs, select_paused_reqs
from ..tokenizer import get_tokenizer
from lightllm.utils.log_utils import init_logger
from lightllm.utils.profiler import LocalProfiler
from lightllm.server.router.token_load import TokenLoad
from lightllm.server.req_id_generator import convert_sub_id_to_group_id
from lightllm.server.metrics.manager import MetricClient
@@ -79,6 +80,8 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr
# mainly to prevent scheduling mistakes from causing OOM and similar errors
self.router_lock = mp.Lock()
g_router_lock.obj = self.router_lock

self.profiler = LocalProfiler(mode=args.profiler, name="lightllm-router") if args.profiler else None
return

async def wait_to_model_ready(self):
@@ -132,6 +135,7 @@ async def wait_to_model_ready(self):
"mem_fraction": self.args.mem_fraction,
"batch_max_tokens": self.args.batch_max_tokens,
"pd_rpyc_port": self.args.pd_tp_infer_rpyc_ports[rank_id], # 非 pd 模式可以不设置
"profiler": self.profiler if self.world_size == 1 else None,
}
init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs))

@@ -335,8 +339,12 @@ async def _prefill_batch(self, batch: Batch):
await self._init_batch(batch)
if not self.is_splitfuse_mode:
# only in non-splitfuse mode does the prefill operation actually need to be executed.
if self.profiler:
mark_range = self.profiler.mark_range_start(f"prefill len={batch.input_tokens()}")
rets = [self.model_rpcs[tp_rank].prefill_batch(batch.batch_id) for tp_rank in range(self.world_size)]
ans = await asyncio.gather(*rets)
if self.profiler:
self.profiler.mark_range_end(mark_range)
if self.world_size != 1:
req_to_out_status = obtain(ans[0])
else:
@@ -355,12 +363,16 @@
async def _decode_batch(self, batch: Batch):
start_time = time.time()
self.metric_client.counter_inc("lightllm_batch_inference_count", "decode")
if self.profiler:
mark_range = self.profiler.mark_range_start(f"decode bs={len(batch.reqs)}")
rets = [self.model_rpcs[tp_rank].decode_batch(batch.batch_id) for tp_rank in range(self.world_size)]
ans = await asyncio.gather(*rets)
if self.world_size != 1:
req_to_out_status = obtain(ans[0])
else:
req_to_out_status = ans[0]
if self.profiler:
self.profiler.mark_range_end(mark_range)

self._update_out_status_to_batch(batch, req_to_out_status)
unfinished_req_ids, finished_req_ids = batch.mark_and_get_finished_req_and_preupdate_status()
@@ -486,9 +498,24 @@ async def loop_for_netio_req(self):
group_req_id = abort_req.group_req_id
await self.abort(group_req_id)
self.send_to_detokenization.send_pyobj(abort_req)
elif isinstance(recv_req, ProfilerReq):
await self.profiler_ops(recv_req.msg)
else:
assert False, f"Error Req Inf {recv_req}"

async def profiler_ops(self, msg):
# assert self.profiler
if self.world_size != 1:
rets = [self.model_rpcs[tp_rank].profiler_ops(msg) for tp_rank in range(self.world_size)]
await asyncio.gather(*rets)

if msg == "start":
self.profiler.start()
elif msg == "stop":
self.profiler.stop()
else:
assert False, "invalid profiler ops"

def clean_up(self):
for model_rpc in self.model_rpcs:
model_rpc.rpc_server_process.kill()
19 changes: 19 additions & 0 deletions lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -52,6 +52,7 @@
from lightllm.server.router.model_infer.infer_batch import InferBatch, InferReq, InferSamplingParams, requests_mapping
from lightllm.server.router.token_load import TokenLoad
from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock
from lightllm.utils.profiler import LocalProfiler


class ModeBackend:
@@ -88,6 +89,15 @@ def init_model(self, kvargs):
self.pd_rpyc_port = kvargs.get("pd_rpyc_port", None)
max_total_token_num = kvargs["max_total_token_num"]

self.profiler = None
if kvargs.get("profiler") is not None:
# when world_size == 1, the model backend and the router run in the same process, so they share the same profiler object
assert world_size == 1
self.profiler = kvargs.get("profiler")
elif self.args.profiler:
# when world_size > 1, each backend process creates its own profiler
self.profiler = LocalProfiler(mode=self.args.profiler, name=f"lightllm-model_backend-{self.tp_rank}")

dist.init_process_group(
"nccl", init_method=f'tcp://127.0.0.1:{kvargs["nccl_port"]}', rank=self.tp_rank, world_size=world_size
)
@@ -336,3 +346,12 @@ def remove_batch(self, batch_id):
del batch
g_infer_state_lock.release()
return

def profiler_ops(self, msg):
assert self.profiler
if msg == "start":
self.profiler.start()
elif msg == "stop":
self.profiler.stop()
else:
assert False, "invalid profiler ops"
15 changes: 15 additions & 0 deletions lightllm/server/router/model_infer/model_rpc.py
@@ -133,6 +133,11 @@ def exposed_remove_batch(self, batch_id):
def exposed_get_max_total_token_num(self):
return self.backend.get_max_total_token_num()

def exposed_profiler_ops(self, msg):
if self.world_size != 1:
msg = obtain(msg)
return self.backend.profiler_ops(msg)


class ModelRpcClient:
def __init__(self, model_rpc, world_size, rpc_server_process=None):
@@ -161,6 +166,7 @@ async def _func(*args, **kwargs):
self._filter_batch = async_wrap(self.model.filter_batch)
self._merge_batch = async_wrap(self.model.merge_batch)
self._remove_batch = async_wrap(self.model.remove_batch)
self._profiler_ops = async_wrap(self.model.profiler_ops)
self._get_max_total_token_num = async_wrap(self.model.get_max_total_token_num)
else:
self._init_model = self.model.exposed_init_model
@@ -171,6 +177,7 @@
self._filter_batch = self.model.exposed_filter_batch
self._merge_batch = self.model.exposed_merge_batch
self._remove_batch = self.model.exposed_remove_batch
self._profiler_ops = self.model.exposed_profiler_ops
self._get_max_total_token_num = self.model.exposed_get_max_total_token_num
return

@@ -242,6 +249,14 @@ async def get_max_total_token_num(self):
else:
return ans

async def profiler_ops(self, msg):
ans = self._profiler_ops(msg)
if self.use_rpc:
await ans
return
else:
return


def _init_env(args, port, info_queue, mem_queue, router_lock):
# register handlers for graceful shutdown
82 changes: 82 additions & 0 deletions lightllm/utils/profiler.py
@@ -0,0 +1,82 @@
import os
from typing import Any, Literal, Optional
import torch

from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class LocalProfiler:
def __init__(self, mode: Literal["torch_profile", "nvtx"], name: Optional[str] = None):
self.mode: Literal["torch_profile", "nvtx"] = mode
self.name: Optional[str] = name
self.active: bool = False
if self.mode == "torch_profile":
trace_dir = os.getenv("LIGHTLLM_TRACE_DIR", "./trace")
self._torch_profiler = torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
with_stack=True, # additional overhead
on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_dir, worker_name=name, use_gzip=True),
)
logger.warning(
"Profiler support (--profiler=XXX) for torch.profile enabled, trace file will been saved to %s",
trace_dir,
)
logger.warning("do not enable this feature in production")
elif self.mode == "nvtx":
self._nvtx_toplevel_mark = "LIGHTLLM_PROFILE"
self._nvtx_toplevel_id = None
logger.warning(
"""Profiler support (--profiler=XXX) for NVTX enabled, toplevel NVTX mark is %s,
use it with external profiling tools""",
self._nvtx_toplevel_mark,
)
logger.warning(
"""e.g. nsys profile --capture-range=nvtx --nvtx-capture=%s --trace=cuda,nvtx
-e NSYS_NVTX_PROFILER_REGISTER_ONLY=0 [--other_nsys_options]
python -m lightllm.server.api_server --profiler=nvtx [--other_lightllm_options]""",
self._nvtx_toplevel_mark,
)
elif self.mode is not None:
assert False, "invalid profiler mode"

def start(self):
if self.active:
logger.error("profiler already started, ignore")
return
logger.warning("Profiler support: profiling start")
self.active = True
if self.mode == "torch_profile":
self._torch_profiler.start()
elif self.mode == "nvtx":
self._nvtx_toplevel_id = torch.cuda.nvtx.range_start(self._nvtx_toplevel_mark)

def stop(self):
if not self.active:
logger.error("profiler not started, ignore")
return
logger.warning("Profiler support: profiling stop")
self.active = False
if self.mode == "torch_profile":
logger.warning("Profiler support: torch_profiler saving trace file, it might take a while...")
self._torch_profiler.stop()
logger.warning("Profiler support: torch_profiler saving done")
elif self.mode == "nvtx":
torch.cuda.nvtx.range_end(self._nvtx_toplevel_id)

def mark_range_start(self, message: str) -> Any:
"return the handle of the range, to be used in mark_range_end()"
if self.active:
# only supported in NVTX mode
if self.mode == "nvtx":
return torch.cuda.nvtx.range_start(message)

def mark_range_end(self, handle: Any):
if self.active:
# only supported in NVTX mode
if self.mode == "nvtx":
return torch.cuda.nvtx.range_end(handle)
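
For completeness, here is a sketch (not part of this PR) of driving LocalProfiler directly, outside the server processes; it assumes a CUDA-enabled PyTorch build, and in 'torch_profile' mode the mark_range_* calls are simply no-ops:

```python
# Minimal sketch: exercise LocalProfiler on its own. Assumes a CUDA-enabled
# PyTorch build; in "torch_profile" mode mark_range_start/_end do nothing.
import torch
from lightllm.utils.profiler import LocalProfiler

prof = LocalProfiler(mode="nvtx", name="lightllm-demo")

prof.start()                                    # opens the LIGHTLLM_PROFILE NVTX range
handle = prof.mark_range_start("demo matmul")   # nested, named NVTX range
a = torch.randn(1024, 1024, device="cuda")
torch.matmul(a, a)
prof.mark_range_end(handle)
prof.stop()                                     # closes the top-level range
```

Running this under the `nsys profile --capture-range=nvtx --nvtx-capture=LIGHTLLM_PROFILE ...` invocation from the warning above captures only the profiled region.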