diff --git a/.gitignore b/.gitignore index 3878c35..1001d78 100644 --- a/.gitignore +++ b/.gitignore @@ -165,7 +165,7 @@ cython_debug/ # Vim *.swp -.json +*.json token_usage/ run_all.sh diff --git a/format.sh b/format.sh old mode 100644 new mode 100755 diff --git a/skythought/skythought_evals/batch/__init__.py b/skythought/skythought_evals/batch/__init__.py new file mode 100644 index 0000000..868cb18 --- /dev/null +++ b/skythought/skythought_evals/batch/__init__.py @@ -0,0 +1,13 @@ +__all__ = [] + +from .engines import init_engine_from_config +from .pipeline import Pipeline +from .workload import ( + EvalWorkload, +) + +__all__ = [ + "Pipeline", + "init_engine_from_config", + "EvalWorkload", +] diff --git a/skythought/skythought_evals/batch/engines/__init__.py b/skythought/skythought_evals/batch/engines/__init__.py new file mode 100644 index 0000000..2fef759 --- /dev/null +++ b/skythought/skythought_evals/batch/engines/__init__.py @@ -0,0 +1,10 @@ +"""LLM Engines.""" + +__all__ = [] + +from .initializer import EngineInitializerBase, init_engine_from_config + +__all__ = [ + "EngineInitializerBase", + "init_engine_from_config", +] diff --git a/skythought/skythought_evals/batch/engines/base.py b/skythought/skythought_evals/batch/engines/base.py new file mode 100644 index 0000000..2c12759 --- /dev/null +++ b/skythought/skythought_evals/batch/engines/base.py @@ -0,0 +1,22 @@ +"""Engine base.""" + +from typing import Any, AsyncGenerator, Dict + +import numpy as np + + +class EngineBase: + """Base class for engines.""" + + async def __call__( + self, batch: Dict[str, np.ndarray] + ) -> AsyncGenerator[Dict[str, Any], None]: + """Call the LLM engine asynchronously to process a Ray Data batch. + + Args: + batch: The batch. + + Yields: + The output. + """ + raise NotImplementedError diff --git a/skythought/skythought_evals/batch/engines/initializer.py b/skythought/skythought_evals/batch/engines/initializer.py new file mode 100644 index 0000000..9bdf492 --- /dev/null +++ b/skythought/skythought_evals/batch/engines/initializer.py @@ -0,0 +1,264 @@ +"""Engine initializers. +Note that this file should not import any engine dependent modeules, such as +vLLM, because the engine initializer is used in the driver node which may +not have GPUs. +""" + +import os +from pathlib import Path +from typing import Any, Dict, Optional, Union + +import yaml + +from ..utils import ( + download_model_from_hf, + update_dict_recursive, +) +from ..workload import EvalWorkload +from .base import EngineBase + + +class EngineInitializerBase: + """Base class for engine initializer. + + Args: + model_id: The model id. + accelerator_type: The accelerator type. + engine_kwargs: The engine specific configurations. 
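+        lora_adapter: The optional LoRA adapter to load on top of the base model.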
+ ray_env_vars: The Ray runtime environment + """ + + use_ray_placement_group: bool = False + + def __init__( + self, + model_id: str, + accelerator_type: str, + engine_kwargs: Dict[str, Any], + lora_adapter: Optional[str] = None, + ray_env_vars: Dict[str, Any] = None, + ): + self._model = model_id + self._accelerator_type = accelerator_type + self._ray_env_vars = ray_env_vars or {} + self.lora_adapter = lora_adapter + self.engine_kwargs = engine_kwargs + + @property + def model(self) -> str: + return self._model + + @property + def accelerator_type(self) -> str: + return self._accelerator_type + + @property + def ray_env_vars(self) -> Dict[str, str]: + return self._ray_env_vars + + @property + def num_gpus(self) -> int: + """The number of GPUs used per engine.""" + raise NotImplementedError + + @property + def max_model_len(self) -> Optional[int]: + """The maximum model length set by the engine.""" + return None + + def get_engine_cls(self) -> EngineBase: + """Get the engine class. + + Returns: + The engine class. + """ + raise NotImplementedError + + def get_engine_constructor_args(self, workload: EvalWorkload) -> Dict[str, Any]: + """Get the engine constructor arguments. + + Args: + workload: The workload that the engine will process. + + Returns: + The engine constructor keyword arguments. + """ + raise NotImplementedError + + +class vLLMEngineInitializer(EngineInitializerBase): + use_ray_placement_group: bool = False + + def __init__( + self, + model_id: str, + accelerator_type: str, + engine_kwargs: Dict[str, Any], + lora_adapter: Optional[str] = None, + ray_env_vars: Dict[str, Any] = None, + ): + super().__init__( + model_id, accelerator_type, engine_kwargs, lora_adapter, ray_env_vars + ) + + # Override vLLM default configs. Note that this is only effective + # when the config is not set by users. + self.engine_kwargs.setdefault("gpu_memory_utilization", 0.95) + self.engine_kwargs.setdefault("use_v2_block_manager", True) + self.engine_kwargs.setdefault("enable_prefix_caching", False) + self.engine_kwargs.setdefault("enforce_eager", False) + self.engine_kwargs.setdefault("pipeline_parallel_size", 1) + self.engine_kwargs.setdefault("max_num_seqs", 256) + self.engine_kwargs.setdefault("tensor_parallel_size", 1) + self.engine_kwargs.setdefault("max_logprobs", 0) + self.engine_kwargs.setdefault("distributed_executor_backend", "mp") + + # Set engine environment variables. + self._ray_env_vars.setdefault("VLLM_ATTENTION_BACKEND", "FLASH_ATTN") + self._ray_env_vars.setdefault("ENABLE_ANYSCALE_PREFIX_OPTIMIZATIONS", "0") + # FIXME: This should already be deprecated and can be removed. 
+ self._ray_env_vars.setdefault("VLLM_DISABLE_LOGPROBS", "1") + for key, value in self._ray_env_vars.items(): + os.environ[key] = str(value) + + def get_engine_cls(self): + from .vllm_engine import AsyncLLMPredictor + + return AsyncLLMPredictor + + @property + def num_gpus(self) -> int: + assert "tensor_parallel_size" in self.engine_kwargs + assert "pipeline_parallel_size" in self.engine_kwargs + tp_size = self.engine_kwargs["tensor_parallel_size"] + pp_size = self.engine_kwargs["pipeline_parallel_size"] + return tp_size * pp_size + + @property + def max_model_len(self) -> Optional[int]: + """The maximum model length set by the engine.""" + return self.engine_kwargs.get("max_model_len", None) + + def get_engine_constructor_args(self, workload: EvalWorkload): + from vllm import PoolingParams, SamplingParams + from vllm.config import PoolerConfig + + constructor_kwargs = { + "model": self.model, + "lora_adapter": self.lora_adapter, + } + + if sampling_params := workload.sampling_params: + # Sampling params is given: Auto-regressive generation. + # In this case, we need to set max_tokens and max_model_len. + + max_tokens = sampling_params.get("max_tokens", None) + if max_tokens is None: + raise ValueError("max_tokens is required for vLLM engine.") + + vllm_sampling_params = SamplingParams(**workload.sampling_params) + vllm_sampling_params.max_tokens = max_tokens + vllm_sampling_params.detokenize = False + constructor_kwargs["params"] = vllm_sampling_params + + if ( + "max_model_len" not in self.engine_kwargs + and workload.max_tokens_in_prompt < 0 + ): + raise ValueError( + "Neither max_tokens_in_prompt nor max_model_len is set. If you " + "intend to let the pipeline infer max_tokens_in_prompt but got this error, " + "it is either because the workload has not been tokenized, or the " + "workload bypass the tokenizer but does not set max_tokens_in_prompt by itself." + ) + + # Use max_tokens_in_prompt + max_tokens as the max_model_len. max_tokens_in_prompt + # is either inferred by materializing tokenized dataset, set by the workload, or + # set by the engine. + self.engine_kwargs["max_model_len"] = ( + workload.max_tokens_in_prompt + max_tokens + ) + else: + # Sampling params is not given: Embedding workload. + # In this case, we need to set pooling_params and task. + + if workload.pooling_params is None: + raise ValueError( + "pooling_params is required for vLLM engine for embedding workload." + ) + constructor_kwargs["params"] = PoolingParams(**workload.pooling_params) + constructor_kwargs["task"] = "embed" + + # Construct PoolerConfig if override_pooler_config is specified. + if pooler_config := self.engine_kwargs.get("override_pooler_config", None): + self.engine_kwargs["override_pooler_config"] = PoolerConfig( + **pooler_config + ) + + constructor_kwargs.update(self.engine_kwargs) + return constructor_kwargs + + +def init_engine_from_config( + config: Union[Dict[str, Any], str], override: Optional[Dict[str, Any]] = None +) -> EngineInitializerBase: + """Initialize an engine initializer from a config file or a config dict. + + Args: + config: A config file (in YAML) or a config dict. It should include + the following keys: "engine", backend engine to use; "model", + model to use; "accelerator_type", the GPU type; "configs", + the engine specific configurations. + override: Override values in config["configs"]. + + Returns: + An engine initializer. 
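+
+    Example (illustrative values; the keys follow the ones validated below):
+        >>> initializer = init_engine_from_config(
+        ...     {
+        ...         "llm_engine": "vllm",
+        ...         "model_id": "Qwen/Qwen2.5-7B-Instruct",
+        ...         "accelerator_type": "A100",
+        ...         "engine_kwargs": {"tensor_parallel_size": 1},
+        ...     }
+        ... )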
+ """ + if isinstance(config, str): + config_path = Path(config) + if not config_path.exists(): + raise FileNotFoundError(f"Engine config file {config} not found.") + with open(config_path, "r") as filep: + config = yaml.safe_load(filep) + + assert isinstance(config, dict) + + # Override configs + if override is not None: + update_dict_recursive(config, override) + + # Ray runtime environments. + runtime_env: Dict[str, Any] = config.get("runtime_env", {}) + ray_env_vars: Dict[str, Any] = runtime_env.get("env_vars", {}) + + # Download model and save to local path in advance, in case + # too many worker downloads the model in parallel and hit huggingface rate limit. + assert "model_id" in config and isinstance(config["model_id"], str) + if ray_env_vars.pop("PREDOWNLOAD_MODEL_FROM_HF", "0") == "1": + config["model_id"] = download_model_from_hf( + config["model_id"], "/mnt/cluster_storage" + ) + + # Do not download LoRA adapter here because it is not used in the driver node. + lora_adapter = None + if "lora_config" in config: + lora_adapter = config["lora_config"].get("dynamic_lora_loading_path", None) + + # Sanity check for engine kwargs. + for key in ("llm_engine", "model_id", "accelerator_type"): + if key not in config: + raise KeyError(f"Required {key} not found in config.") + if "engine_kwargs" not in config: + config["engine_kwargs"] = {} + + name = config["llm_engine"] + if name == "vllm": + return vLLMEngineInitializer( + model_id=config["model_id"], + accelerator_type=config["accelerator_type"], + engine_kwargs=config["engine_kwargs"], + lora_adapter=lora_adapter, + ray_env_vars=ray_env_vars, + ) + + raise ValueError(f"Unknown engine: {name}") diff --git a/skythought/skythought_evals/batch/engines/vllm_engine.py b/skythought/skythought_evals/batch/engines/vllm_engine.py new file mode 100644 index 0000000..9e2c255 --- /dev/null +++ b/skythought/skythought_evals/batch/engines/vllm_engine.py @@ -0,0 +1,444 @@ +"""The vLLM engine.""" + +import asyncio +import dataclasses +import math +import os +import sys +import time +import uuid +from dataclasses import dataclass +from enum import Enum +from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union + +import msgspec +import numpy as np +import ray +from packaging import version +from vllm import AsyncEngineArgs, AsyncLLMEngine, PoolingParams, SamplingParams +from vllm.inputs.data import TextPrompt, TokensPrompt +from vllm.lora.request import LoRARequest +from vllm.outputs import PoolingRequestOutput, RequestOutput + +from ..logging import get_logger +from ..utils import ( + async_caller_empty_batch_handler, + maybe_download_model_from_s3, + wait_for_gpu_memory_to_clear, +) +from .base import EngineBase + +logger = get_logger(__name__) + + +@dataclass(frozen=True) +class LLMRequest: + """A request to the LLM wrapper.""" + + # Index in the batch. + idx_in_batch: int + # The request ID for the LLM engine (unique per replica). + request_id: int + # The full prompt string (with chat template applied if any). + prompt: str + # The tokenized prompt IDs. If None, then the string prompt will be + # tokenized by the LLM engine. This is not recommended for performance reasons. + prompt_token_ids: Optional[List[int]] + # The sampling or pooling parameters. + params: Union[SamplingParams, PoolingParams] + # Custom data to be passed through to the output. + custom_data: Dict[str, Any] + # (optional) LoRA adapter. 
+ lora_request: Optional[LoRARequest] = None + + +class AsyncLLMWrapper: + """Wrapper around the vLLM engine to handle async requests. + + Args: + *args: The positional arguments for the engine. + max_pending_requests: The maximum number of pending requests in the queue. + **kwargs: The keyword arguments for the engine. + """ + + def __init__(self, *args, max_pending_requests: int = -1, **kwargs): + engine_args = AsyncEngineArgs( + *args, + **kwargs, + disable_log_requests=True, + ) + self.engine = AsyncLLMEngine.from_engine_args(engine_args) + self.max_pending_requests = max_pending_requests + + # Determine the generate function based on vLLM v0 or v1. + if os.getenv("VLLM_USE_V1", "0") == "1": + self._generate_async = self.generate_async_v1 + else: + self._generate_async = self.generate_async_v0 + + # FIXME: The asyncio queue crashes in Python 3.9 and Ray 2.37- with the following error: + # got Future attached to a different loop, because Ray Data + # creates a new event loop. This can be removed when + # https://github.com/ray-project/ray/issues/47734 is released. + if ( + version.parse(ray.__version__) <= version.parse("2.37.0") + and sys.version_info.minor == 9 + ): + if self.max_pending_requests > 0: + logger.warning( + "max_pending_requests is disabled due to a known issue with asyncio " + "in Python 3.9 with Ray 2.37-" + ) + self.max_pending_requests = 0 + + # vLLM performance gets really bad if there are too many requests in the pending queue. + # We work around it by introducing another queue that gates how many requests we are + # sending to vLLM at once. + # This is not a queue of requests - instead, this queue holds "slots". Each time + # we add a new request, we take one slot. Each time a request finishes, we add a new + # slot. + self.free_queue: asyncio.Queue[bool] = asyncio.Queue() + if self.max_pending_requests > 0: + for _ in range(self.max_pending_requests): + self.free_queue.put_nowait(True) + + async def generate_async( + self, request: LLMRequest + ) -> Tuple[LLMRequest, RequestOutput]: + """Process a single request. + + Args: + request: The request. + + Returns: + A tuple of index in batch, request output and bypassed custom fields. + """ + # If free queue is used, guard the request here until a slot is available. + if self.max_pending_requests > 0: + await self.free_queue.get() + + ret = await self._generate_async(request) + + # If free queue is used, release the slot. + if self.max_pending_requests > 0: + self.free_queue.put_nowait(True) + + return ret + + async def generate_async_v0( + self, request: LLMRequest + ) -> Tuple[LLMRequest, RequestOutput]: + """Process a single request. + + Args: + request: The request. + + Returns: + A tuple of index in batch, request output and bypassed custom fields. + """ + if request.prompt_token_ids is not None: + llm_prompt = TokensPrompt(prompt_token_ids=request.prompt_token_ids) + else: + assert request.prompt + llm_prompt = TextPrompt(prompt=request.prompt) + + # Send the request to the LLM engine. + stream = await self.engine.add_request( + request_id=str(request.request_id), + prompt=llm_prompt, + params=request.params, + lora_request=request.lora_request, + ) + # Consume the stream until the request is finished. + async for request_output in stream: + if request_output.finished: + # Bypass the original full prompt. 
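+                # (For token-ID prompts the engine output has no prompt string, so
+                # restore the original prompt before returning it downstream.)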
+ request_output.prompt = request.prompt + return (request, request_output) + raise RuntimeError("Should not reach here") + + async def generate_async_v1( + self, request: LLMRequest + ) -> Tuple[LLMRequest, RequestOutput]: + """Process a single request. + + Args: + request: The request. + + Returns: + A tuple of index in batch, request output and bypassed custom fields. + """ + # NOTE: vLLM v1 tighly couples tokenizer and detokenizer to the engine, + # so we should set tokenize=False in .run() to avoid redundant tokenization + # for better performance (although the impact should be minimal). + assert request.prompt + llm_prompt = TextPrompt(prompt=request.prompt) + + # Send the request to the LLM engine. + stream = self.engine.generate( + request_id=str(request.request_id), + prompt=llm_prompt, + sampling_params=request.params, + lora_request=request.lora_request, + ) + + # Consume the stream until the request is finished. + async for request_output in stream: + if request_output.finished: + # Bypass the original full prompt. + request_output.prompt = request.prompt + return (request, request_output) + + raise RuntimeError("Should not reach here") + + +class AsyncLLMPredictor(EngineBase): + """Async LLM predictor. + + Args: + model: The model name. + params: The sampling or pooling parameters. + lora_adapter: The LoRA adapter. + max_pending_requests: The maximum number of pending requests. + **kwargs: The keyword arguments for the engine. + """ + + def __init__( + self, + model: str, + params: Union[SamplingParams, PoolingParams], + lora_adapter: Optional[str] = None, + max_pending_requests: Optional[int] = None, + **kwargs, + ): + # Sanity check. + for key in ( + "enable_prefix_caching", + "enforce_eager", + "pipeline_parallel_size", + "tensor_parallel_size", + "max_num_seqs", + ): + assert key in kwargs, f"[InternalError] {key} not found in engine_kwargs." + + # Download model from S3 if needed. + model = maybe_download_model_from_s3(model) + # Download LoRA adapter from S3 if needed. + if lora_adapter is not None: + lora_adapter = maybe_download_model_from_s3(lora_adapter) + + wait_for_gpu_memory_to_clear(1000 * 2**20) + self.request_id = 0 + self.enable_prefix_caching = kwargs["enable_prefix_caching"] + self.params = params + self.lora_request = ( + LoRARequest("adapter", 1, lora_adapter) + if lora_adapter is not None + else None + ) + if self.lora_request is not None: + logger.info("LoRA adapter is enabled: %s", lora_adapter) + # Enforce enable_lora=True in the engine kwargs + kwargs["enable_lora"] = True + + # Set max_logprobs to the maximum of sampling_logprobs and sampling_prompt_logprobs. + if isinstance(params, SamplingParams): + sampling_logprobs = params.logprobs or 0 + sampling_prompt_logprobs = params.prompt_logprobs or 0 + kwargs["max_logprobs"] = max(sampling_logprobs, sampling_prompt_logprobs) + + attn_backend = os.getenv("VLLM_ATTENTION_BACKEND", "FLASH_ATTN") + if attn_backend == "FLASHINFER" and self.enable_prefix_caching: + if kwargs["kv_cache_dtype"].startswith("fp8"): + # FlashInfer does not support bfloat16 activations. + kwargs["dtype"] = "float16" + + self.enforce_eager = kwargs["enforce_eager"] + pp_size = kwargs["pipeline_parallel_size"] + self.max_pending_requests = max_pending_requests or math.ceil( + kwargs["max_num_seqs"] * pp_size * 1.1 + ) + if self.max_pending_requests > 0: + logger.info("Max pending requests is set to %d", self.max_pending_requests) + + # Create an LLM. 
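+        # The remaining engine kwargs (tensor_parallel_size, max_num_seqs, etc.)
+        # are forwarded unchanged to vLLM's AsyncEngineArgs by the wrapper.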
+ self.llm = AsyncLLMWrapper( + model=model, + disable_log_stats=False, + max_pending_requests=self.max_pending_requests, + **kwargs, + ) + + def _prepare_llm_input(self, batch: Dict[str, Any]) -> List[LLMRequest]: + """Prepare the inputs for LLM inference. + + Args: + batch: The batch. + + Returns: + A list of LLMRequest. + """ + if "prompt" not in batch: + raise ValueError( + "Required 'prompt' not found in batch. This may be " + "due to an unknown internal error if your workload needs " + "tokenization. If your workload does not need tokenization, " + "please make sure 'prompt' exists in the dataset." + ) + prompt = batch.pop("prompt").tolist() + + if "tokenized_prompt" in batch: + tokenized_prompt = batch.pop("tokenized_prompt").tolist() + else: + tokenized_prompt = [None] * len(prompt) + + # If sampling_params is provided in the batch, override the default. + if "sampling_params" in batch: + sampling_params_dict = batch.pop("sampling_params").tolist() + params = [SamplingParams(**s) for s in sampling_params_dict] + else: + params = [self.params] * len(prompt) + + # Rest fields are custom data. + keys, values = list(batch.keys()), zip(*batch.values()) + custom_data = [dict(zip(keys, v)) for v in values] + + # Organize data to be LLM requests. + requests = [] + for idx, (p, pt, sp, cd) in enumerate( + zip(prompt, tokenized_prompt, params, custom_data) + ): + requests.append( + LLMRequest( + idx_in_batch=idx, + request_id=self.request_id, + prompt=p, + prompt_token_ids=pt, + params=sp, + custom_data=cd, + lora_request=self.lora_request, + ) + ) + self.request_id += 1 + return requests + + def _parse_llm_output( + self, output: Union[RequestOutput, PoolingRequestOutput] + ) -> Dict[str, Any]: + """Parse the LLM output. + + Args: + output: The LLM output. + + Returns: + The parsed output. + """ + # Parse the common fields. + output_data = { + "prompt": [output.prompt], + "prompt_token_ids": [output.prompt_token_ids], + "num_input_tokens": [len(output.prompt_token_ids)], + "request_id": [output.request_id], + } + + if isinstance(output, RequestOutput): + metrics = {} + if output.metrics is not None: + metrics = { + k: [v] for k, v in dataclasses.asdict(output.metrics).items() + } + generated_tokens = [ + output.outputs[i].token_ids for i in range(len(output.outputs)) + ] + num_generated_tokens = [ + len(output.outputs[i].token_ids) for i in range(len(output.outputs)) + ] + output_data.update( + { + "generated_tokens": ( + [generated_tokens] + if len(generated_tokens) > 1 + else generated_tokens + ), + "num_generated_tokens": ( + [num_generated_tokens] + if len(num_generated_tokens) > 1 + else num_generated_tokens + ), + **metrics, + } + ) + elif isinstance(output, PoolingRequestOutput): + output_data.update( + { + "embeddings": [output.outputs.data.cpu()], + } + ) + else: + raise ValueError(f"Unknown output type: {type(output)}") + + return output_data + + async def call_async( + self, batch: Dict[str, np.ndarray] + ) -> AsyncGenerator[Dict[str, Any], None]: + """Call the LLM asynchronously to process a batch. + + Args: + batch: The batch. + + Yields: + The output. 
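+            Each yielded dict corresponds to one finished request; every value is
+            wrapped in a single-element list so Ray Data can assemble the rows.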
+ """ + batch_uuid = uuid.uuid4() + t = time.perf_counter() + + requests = self._prepare_llm_input(batch) + tasks = [ + asyncio.create_task(self.llm.generate_async(request)) + for request in requests + ] + + time_taken = -1.0 + for resp in asyncio.as_completed(tasks): + request, output = await resp + time_taken = time.perf_counter() - t + index_in_batch = request.idx_in_batch + param_dict = msgspec.structs.asdict(request.params) + # Convert RequestOutputKind (Enum) to integer value. + if "output_kind" in param_dict and isinstance( + param_dict["output_kind"], Enum + ): + param_dict["output_kind"] = param_dict["output_kind"].value + custom_data = request.custom_data + custom_data = {k: [v] for k, v in custom_data.items()} + + yield { + **self._parse_llm_output(output), + "batch_uuid": [batch_uuid.hex], + "time_taken_llm": [time_taken], + "index_in_batch": [index_in_batch], + "params": [param_dict], + **custom_data, + } + logger.info( + "[vLLM] Elapsed time for batch %s with size %d: %s", + batch_uuid.hex, + len(requests), + time_taken, + ) + + @async_caller_empty_batch_handler + async def __call__( + self, batch: Dict[str, np.ndarray] + ) -> AsyncGenerator[Dict[str, Any], None]: + """Call the LLM asynchronously to process a batch. + + Args: + batch: The batch. + + Yields: + The output. + """ + async for x in self.call_async(batch): + yield x diff --git a/skythought/skythought_evals/batch/env_config.py b/skythought/skythought_evals/batch/env_config.py new file mode 100644 index 0000000..7a5a3bd --- /dev/null +++ b/skythought/skythought_evals/batch/env_config.py @@ -0,0 +1,41 @@ +"""Environment configurations for Ray.""" + +from dataclasses import dataclass +from typing import Dict, Optional + +from .logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class EnvConfig: + """Environment configurations for Ray.""" + + # General configurations. + hf_token: Optional[str] = None + ray_override_job_runtime_env: str = "1" + + # Ray Data configurations. + ray_data_default_wait_for_min_actors_s: int = 600 + + # The number of LLM engine replicas to use. + num_replicas: int = 1 + # The batch size. This represents the unit of fault tolerance. + # Smaller batch size implies more fault tolerance but may + # introduce more overhead. Batch size should at least be 16 to + # avoid hanging. + batch_size: int = 256 + + def gen_ray_runtime_envs(self, engine_envs: Dict[str, str]) -> Dict[str, str]: + """Generate Ray runtime environment variables.""" + envs = {k.upper(): str(v) for k, v in engine_envs.items()} + + for key in ( + "hf_token", + "ray_override_job_runtime_env", + "ray_data_default_wait_for_min_actors_s", + ): + if getattr(self, key) is not None: + envs[key.upper()] = str(getattr(self, key)) + return envs diff --git a/skythought/skythought_evals/batch/logging/__init__.py b/skythought/skythought_evals/batch/logging/__init__.py new file mode 100644 index 0000000..f67378e --- /dev/null +++ b/skythought/skythought_evals/batch/logging/__init__.py @@ -0,0 +1,55 @@ +"""Logging.""" + +import logging +from typing import Optional + +from ray._private.ray_logging.filters import CoreContextFilter +from ray._private.ray_logging.formatters import JSONFormatter + + +def _add_ray_logging(handler: logging.Handler): + """Add Ray logging to the handler. + + This is not used for now and will be enabled after the Ray Job is supported. + + Args: + handler: The handler to add Ray logging to. 
+ """ + handler.addFilter(CoreContextFilter()) + handler.setFormatter(JSONFormatter()) + + +def _setup_logger(logger_name: str): + """Setup logger given the logger name. + + This function is idempotent and won't set up the same logger multiple times. + + Args: + logger_name: The name of the logger. + """ + logger = logging.getLogger(logger_name) + + # Skip setup if the logger already has handlers setup. + if logger.handlers: + return + + handler = logging.StreamHandler() + logger.addHandler(handler) + logger.setLevel(logging.INFO) + logger.propagate = False + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """Get a structured logger. + + Loggers by default are logging to stdout, and are expected to be scraped by an + external process. + + Args: + name: The name of the logger. + + Returns: + A logger instance. + """ + _setup_logger(name) + return logging.getLogger(name) diff --git a/skythought/skythought_evals/batch/pipeline.py b/skythought/skythought_evals/batch/pipeline.py new file mode 100644 index 0000000..b86faaa --- /dev/null +++ b/skythought/skythought_evals/batch/pipeline.py @@ -0,0 +1,281 @@ +"""Pipeline for batch processing large-scale LLM workloads.""" + +import os +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +import ray +from ray.data._internal.stats import DatasetStats +from ray.data.dataset import Dataset +from ray.util import remove_placement_group +from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +from .engines import EngineInitializerBase, init_engine_from_config +from .env_config import EnvConfig +from .logging import get_logger +from .tokenizer import Detokenizer +from .workload import EvalWorkload + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = get_logger(__name__) + + +class Pipeline: + """Pipeline for batch processing large-scale LLM workloads. + + Args: + engine_initializer: An engine initializer to create and initialize an engine. + workload: Workload instance. + env_config: EnvConfig to provide environment configurations of Ray. + """ + + def __init__( + self, + engine_initializer: EngineInitializerBase, + env_config: EnvConfig, + ): + self.engine_initializer = engine_initializer + self.env_config = env_config + self.num_replicas: int = self.env_config.num_replicas + self.ds: Optional[Dataset] = None + self.stats: Optional[DatasetStats] = None + + self.pgs: List["PlacementGroup"] = [] + + if not ray.is_initialized(): + ray.init(runtime_env={"env_vars": self.env_vars}) + + @classmethod + def from_config( + cls, engine_cfg: Union[Dict[str, Any], str], workload: EvalWorkload, **kwargs + ): + """Initialize the pipeline from a configuration file or dictionary. + + Args: + engine_cfg: A config file (in YAML) or a config dict. It should include + the following keys: "engine", backend engine to use; "model", + model to use; "accelerator_type", the GPU type; "configs", + the engine specific configurations. + workload: Workload instance. + **kwargs: environment configuration parameters. See `EnvConfig` for more details. 
+ """ + engine_initializer = init_engine_from_config(engine_cfg) + env_config = EnvConfig(**kwargs) + return cls(engine_initializer, workload, env_config) + + @property + def env_vars(self) -> Dict[str, Any]: + return self.env_config.gen_ray_runtime_envs( + self.engine_initializer.ray_env_vars + ) + + def load( + self, + repartition_by_batch_size: bool = False, + ) -> Dataset: + """Use the given workload to load and process the dataset, + and then tokenize the prompts if needed. The processed dataset + will be repartitioned based on the number of replicas and batch size. + + Args: + repartition_by_batch_size: Whether to repartition the dataset by the + batch size for fault tolerance granularity. You should enable + this when the dataset is not from parquet and checkpointing is + disabled. + + Returns: + The processed dataset. + """ + ds, num_blocks = self.workload.get_preprocessed_dataset( + self.env_config.batch_size, + repartition_by_batch_size, + ) + if num_blocks is not None and num_blocks < self.num_replicas: + logger.warning( + "The number of blocks (%d) is less than the number of replicas (%d). " + "This may result in suboptimal performance.", + num_blocks, + self.num_replicas, + ) + + if self.workload.need_tokenize: + # TODO: Figure out a better concurrency. + # Now we simply assume each LLM replica could have 4 tokenizers. + # This is a heuristic and may not be optimal. + tokenizer_concurrency = self.num_replicas * 4 + ds = ds.map_batches( + self.workload.tokenizer_cls, + fn_constructor_kwargs=self.workload.tokenizer_constructor_kwargs( + self.engine_initializer.model + ), + zero_copy_batch=True, + concurrency=(1, tokenizer_concurrency), + batch_size=self.env_config.batch_size, + ) + + # If max tokens in prompt is not set in the workload and max_model_len is not set + # in the engine, we need to materialize the dataset to get the maximum tokens in prompt. + # This may hurt the overall throughput but may be memory efficient. + if self.workload.max_tokens_in_prompt == -1: + if self.engine_initializer.max_model_len is not None: + max_tokens = self.workload.sampling_params.get("max_tokens", 0) + max_tokens_in_prompt = ( + self.engine_initializer.max_model_len - max_tokens + ) + msg = f"Max Prompt Tokens (max_model_len - max_tokens): {max_tokens_in_prompt}" + else: + logger.info( + "Materializing dataset after tokenization to get max prompt tokens" + ) + ds = ds.materialize() + + max_tokens_in_prompt = int(ds.max("num_text_tokens")) + msg = f"Max Prompt Tokens (inferred): {max_tokens_in_prompt}" + self.workload.max_tokens_in_prompt = max_tokens_in_prompt + else: + msg = f"Max Prompt Tokens (specified in wokrload): {self.workload.max_tokens_in_prompt}" + + logger.info(msg) + self.ds = ds + return ds + + def __call__(self, workload: EvalWorkload): + self.workload: EvalWorkload = workload + # Set the task to "embed" if sampling params are not given. + self.task_type_str: str = ( + "auto" if self.workload.sampling_params is not None else "embed" + ) + return self.run(eager=False) + + def run( + self, + dataset: Optional[Dataset] = None, + output_path: Optional[str] = None, + detokenize: bool = True, + eager: bool = True, + repartition_by_batch_size: bool = False, + ) -> Optional[Dataset]: + """Perform batch processing on the dataset with LLM engines. + + Args: + dataset: The dataset to process. If None, we directly use the given workload + to load and process the dataset. + output_path: The output path to write the processed dataset to parquet. 
It can be + a path to a S3 bucket, or a path to local disk (with local:// as the prefix). If None, + the processed dataset will be materialized but not be written. + detokenize: Whether to detokenize the generated text. Default is True. + eager: Whether to run the pipeline eagerly. If True, the dataset will be materialized. + If False, we skip the materialization step and return the dataset. If output_path is specified, + the dataset will be written to files and therefore will be materialized + regardless of the eager flag. + repartition_by_batch_size: Whether to repartition the dataset by the + batch size for fault tolerance granularity. You should enable + this when the dataset is not from parquet and checkpointing is + disabled. + + Returns: + The processed dataset. If output_path is not None, the dataset will be None after writing. + """ + if not eager and output_path is not None: + logger.warning("Eager mode is enforced because output path is specified") + eager = True + + # Expend output_path in case environment variable is used. + if output_path is not None: + output_path = os.path.expanduser(output_path) + + # Force skipping detokenizer if task is "embed". + if self.task_type_str == "embed" and detokenize: + logger.info("Detokenization is skipped because of embedding workload") + detokenize = False + + ray_remote_args = {} + if self.engine_initializer.accelerator_type: + ray_remote_args["accelerator_type"] = ( + self.engine_initializer.accelerator_type + ) + ray_remote_args.update({"runtime_env": {"env_vars": self.env_vars}}) + + if dataset is not None: + self.ds = dataset + elif self.ds is None: + self.load(repartition_by_batch_size) + assert self.ds is not None + + num_gpus = self.engine_initializer.num_gpus + if self.engine_initializer.use_ray_placement_group: + # Specify the number of GPUs required per LLM instance. + # Note: for TP>1, num_gpus has to be 0 - instead, we specify a placement group + if self.engine_initializer.num_gpus > 1: + + def _scheduling_strategy_fn( + num_gpus_per_instance: int, accelerator_type: str + ): + def _get_bundle() -> Dict[str, float]: + bundle: Dict[str, float] = {"GPU": 1, "CPU": 1} + if accelerator_type: + bundle[f"accelerator_type:{accelerator_type}"] = 0.001 + return bundle + + pg = ray.util.placement_group( + [_get_bundle()] * num_gpus_per_instance, + strategy="STRICT_PACK", + ) + self.pgs.append(pg) + return dict( + scheduling_strategy=PlacementGroupSchedulingStrategy( + pg, placement_group_capture_child_tasks=True + ) + ) + + ray_remote_args.update( + _scheduling_strategy_fn( + self.engine_initializer.num_gpus, + self.engine_initializer.accelerator_type, + ) + ) + + self.ds = self.ds.map_batches( + self.engine_initializer.get_engine_cls(), + fn_constructor_kwargs=self.engine_initializer.get_engine_constructor_args( + self.workload + ), + zero_copy_batch=True, + # The number of running actors. + concurrency=self.env_config.num_replicas, + # The number of running batches for an actor in Ray Core level. + # The value may not be optimal when the batch size is too small, + # but it should be good enough for batch size >= 64. + max_concurrency=4, + batch_size=self.env_config.batch_size, + num_gpus=num_gpus, + **ray_remote_args, + ) + + # Skip detokenization. Usually used for tuning, profiling, and embedding. 
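+        # When enabled, the Detokenizer stage maps generated token IDs back to text.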
+ if detokenize: + self.ds = self.ds.map_batches( + Detokenizer, + fn_constructor_kwargs={"model": self.engine_initializer.model}, + zero_copy_batch=True, + concurrency=(1, self.num_replicas), + batch_size=self.env_config.batch_size, + ) + + if output_path is not None: + # Dataset will become None after writing to parquet. + self.ds = self.ds.write_parquet(output_path) + elif eager: + self.ds = self.ds.materialize() + + # If the dataset pipeline is executed due to eager mode, we can cleanup. + if eager: + self.cleanup() + + return self.ds + + def cleanup(self): + for pg in self.pgs: + remove_placement_group(pg) + self.pgs.clear() diff --git a/skythought/skythought_evals/batch/tokenizer.py b/skythought/skythought_evals/batch/tokenizer.py new file mode 100644 index 0000000..a680a02 --- /dev/null +++ b/skythought/skythought_evals/batch/tokenizer.py @@ -0,0 +1,180 @@ +"""Tokenizer and detokenizer for LLMs.""" + +import time +from typing import Any, AsyncGenerator, Dict, Union + +import numpy as np +from transformers import ( + AutoProcessor, + AutoTokenizer, + PreTrainedTokenizer, # type: ignore + PreTrainedTokenizerFast, +) + +from .logging import get_logger +from .utils import async_caller_empty_batch_handler, maybe_download_model_from_s3 + +AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, Any] + +logger = get_logger(__name__) + + +def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: + """Get tokenizer with cached properties. + + This will patch the tokenizer object in place. + By default, transformers will recompute multiple tokenizer properties + each time they are called, leading to a significant slowdown. This + function caches these properties for faster access. + + Args: + tokenizer: The tokenizer object. + + Returns: + The patched tokenizer object. + """ + chat_template = getattr(tokenizer, "chat_template", None) + # For VLM, the text tokenizer is wrapped by a processor. + if hasattr(tokenizer, "tokenizer"): + tokenizer = tokenizer.tokenizer + # Some VLM's tokenizer has chat_template attribute (e.g. Qwen/Qwen2-VL-7B-Instruct), + # however some other VLM's tokenizer does not have chat_template attribute (e.g. + # mistral-community/pixtral-12b). Therefore, we cache the processor's chat_template. + if chat_template is None: + chat_template = getattr(tokenizer, "chat_template", None) + + tokenizer_all_special_ids = set(tokenizer.all_special_ids) + tokenizer_all_special_tokens_extended = tokenizer.all_special_tokens_extended + tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) + tokenizer_len = len(tokenizer) + + class CachedTokenizer(tokenizer.__class__): # type: ignore + @property + def all_special_ids(self): + return tokenizer_all_special_ids + + @property + def all_special_tokens(self): + return tokenizer_all_special_tokens + + @property + def all_special_tokens_extended(self): + return tokenizer_all_special_tokens_extended + + @property + def chat_template(self): + return chat_template + + def __len__(self): + return tokenizer_len + + CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" + + tokenizer.__class__ = CachedTokenizer + return tokenizer + + +class ChatTemplateTokenizer: + """Tokenizer with chat template applied. + + Args: + model: The model name. 
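+
+    Consumes a "messages" column of chat messages and emits "prompt",
+    "tokenized_prompt", "num_text_tokens", and "time_taken_tokenizer" columns.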
+ """ + + def __init__(self, model: str) -> None: + self.model = maybe_download_model_from_s3(model) + self.tokenizer = get_cached_tokenizer(AutoProcessor.from_pretrained(self.model)) + + @async_caller_empty_batch_handler + async def __call__( + self, batch: Dict[str, np.ndarray] + ) -> AsyncGenerator[Dict[str, Any], None]: + """Call the tokenizer to process a batch. + This function first process inputs in the batch asynchronously to apply + chat template because this step cannot be batched. Then it tokenizes all inputs at once. + + Args: + batch: The batch. + + Yields: + The output. + """ + if "messages" not in batch: + raise KeyError(f'"messages" not found in {batch.keys()=}') + + start_t = time.perf_counter() + messages = batch["messages"].tolist() + + # Tokenize text prompts. + full_prompts = [ + self.tokenizer.apply_chat_template( + message, tokenize=False, add_generation_prompt=True + ) + for message in messages + ] + tokens = self.tokenizer(full_prompts)["input_ids"] + time_taken_tokenizer = time.perf_counter() - start_t + + ret = { + **batch, + "prompt": full_prompts, + "tokenized_prompt": tokens, + "num_text_tokens": [len(t) for t in tokens], + "time_taken_tokenizer": [time_taken_tokenizer] * len(tokens), + } + + yield ret + + +class Detokenizer: + """Detokenizer for LLMs. + + Args: + model: The model name. + """ + + def __init__(self, model: str) -> None: + self.model = maybe_download_model_from_s3(model) + self.tokenizer = get_cached_tokenizer(AutoTokenizer.from_pretrained(self.model)) + + async def __call__( + self, batch: Dict[str, np.ndarray] + ) -> AsyncGenerator[Dict[str, Any], None]: + """Detokenize the batch. + + Args: + batch: The batch data. + + Returns: + The detokenized batch. + """ + start_t = time.perf_counter() + generated_tokens = batch["generated_tokens"] + flattened = False + # if the generated tokens are nested lists, flatten them + if isinstance(generated_tokens[0][0], np.ndarray): + # flatten the lists of lists for detokenization + flattened = True + generated_tokens = [ + token for tokens in generated_tokens for token in tokens + ] # flattens list + generated_text = self.tokenizer.batch_decode( + generated_tokens, skip_special_tokens=True + ) + if flattened: + # unflatten the list back to original structure + curr_idx = 0 + generated_text_unflattened = [] + for sublist in batch["generated_tokens"]: + sublist_len = len(sublist) + generated_text_unflattened.append( + generated_text[curr_idx : curr_idx + sublist_len] + ) + curr_idx += sublist_len + generated_text = generated_text_unflattened + time_taken_detokenizer = time.perf_counter() - start_t + yield { + **batch, + "generated_text": generated_text, + "time_taken_detokenizer": [time_taken_detokenizer] * len(generated_text), + } diff --git a/skythought/skythought_evals/batch/utils.py b/skythought/skythought_evals/batch/utils.py new file mode 100644 index 0000000..1c57ef1 --- /dev/null +++ b/skythought/skythought_evals/batch/utils.py @@ -0,0 +1,269 @@ +"""Utility functions""" + +import os +import subprocess +import time +from functools import wraps +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional + +import pyarrow +import ray +from filelock import FileLock +from huggingface_hub import snapshot_download +from pynvml import nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo, nvmlInit # type: ignore +from ray.data import Dataset + +from .logging import get_logger + +logger = get_logger(__name__) + + +# The default local root directory to store models downloaded from S3. 
+# This path should always available on Anyscale platform. If not, then +# we will fallback to FALLBACK_LOCAL_MODEL_ROOT. +DEFAULT_LOCAL_MODEL_ROOT = "/mnt/local_storage/cache" +FALLBACK_LOCAL_MODEL_ROOT = "/tmp/cache" + + +def update_dict_recursive( + orig: Dict[str, Any], update_dict: Dict[str, Any] +) -> Dict[str, Any]: + """Update a dictionary (in-place) recursively. + + Args: + orig: The original dictionary. + update_dict: The dictionary to update. + + Returns: + The updated dictionary. + """ + for key, value in update_dict.items(): + if isinstance(value, dict): + orig[key] = update_dict_recursive(orig.get(key, {}), value) + else: + orig[key] = value + return orig + + +def wait_for_gpu_memory_to_clear(threshold_bytes: int, timeout_s: float = 120) -> None: + """Wait for GPU memory to be below a threshold. + Use nvml instead of pytorch to reduce measurement error from torch cuda context. + + Args: + threshold_bytes: The threshold in bytes. + timeout_s: The timeout in seconds. + + Raises: + ValueError: If the memory is not free after the timeout. + """ + devices = [int(x) for x in ray.get_gpu_ids()] + nvmlInit() + start_time = time.monotonic() + while True: + output = {} + output_raw = {} + for device in devices: + dev_handle = nvmlDeviceGetHandleByIndex(device) + mem_info = nvmlDeviceGetMemoryInfo(dev_handle) + gb_used = mem_info.used / 2**30 + output_raw[device] = gb_used + output[device] = f"{gb_used:.02f}" + + logger.info( + "GPU memory used (GB): " + "; ".join(f"{k}={v}" for k, v in output.items()) + ) + + dur_s = time.monotonic() - start_time + if all(v <= (threshold_bytes / 2**30) for v in output_raw.values()): + logger.info( + "Done waiting for free GPU memory on devices %s (%.2f GB) %.02f s", + devices, + threshold_bytes / 2**30, + dur_s, + ) + break + + if dur_s >= timeout_s: + raise ValueError( + f"Memory of devices {devices=} not free after " + f"{dur_s=:.02f} ({threshold_bytes/2**30=})" + ) + + time.sleep(5) + + +def run_s3_command(command: List[str], error_msg: Optional[str] = None) -> Any: + """Run a S3 command and raise an exception if it fails. + + Args: + command: The command to run. + error_msg: The error message to raise if the command fails. + + Returns: + The result of the command. + """ + try: + return subprocess.run(command, check=True, capture_output=True) + except Exception as err: + # Not using logger.exception since we raise anyway. + if isinstance(err, (subprocess.TimeoutExpired, subprocess.CalledProcessError)): + stdout_txt = f"\nSTDOUT: {err.stdout.decode()}" if err.stdout else "" + stderr_txt = f"\nSTDERR: {err.stderr.decode()}" if err.stderr else "" + else: + stdout_txt = "" + stderr_txt = "" + + if error_msg is not None: + logger.error( + "(%s) %s. Command %s.%s%s", + str(err), + error_msg, + command, + stdout_txt, + stderr_txt, + ) + raise + + +def download_hf_model_from_s3(s3_path: str, local_path_root: str) -> str: + """Download model files from s3 to the local path. The model path prefix + will be added to the local path. + + Args: + s3_path: The s3 path to download from. + local_path_root: The local path root to download to. + + Returns: + The local path where the files are downloaded. + """ + if not s3_path.startswith("s3://"): + raise ValueError(f"Invalid s3 path: {s3_path}") + + prefix = "/".join(s3_path.split("/")[3:]) + local_path = Path(local_path_root) / prefix + + # Use aws s3 sync to make sure we don't download the same files again. 
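+    # ("aws s3 sync" only transfers objects that are missing or changed locally,
+    # so repeated calls are incremental.)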
+ command = ["aws", "s3", "sync", s3_path, local_path] + + logger.info( + "Downloading %s to %s using %s", + s3_path, + local_path, + command, + ) + with FileLock(local_path / ".lock", timeout=-1): + run_s3_command(command, f"Failed to sync model from {s3_path} to {local_path}") + return str(local_path) + + +def maybe_download_model_from_s3( + model_path: str, local_path_root: Optional[str] = None +) -> str: + """Download model from s3 to the local path, and return the local model path. + + Args: + model_path: The maybe s3 path to download from. + lora_path_root: The local path root to download to. If not provided, + will use the default path (/mnt/local_storage/cache or /tmp/cache). + + Returns: + The local path where the model is downloaded. + """ + s3_path = os.path.expandvars(model_path) + if not s3_path.startswith("s3://"): + return model_path + + local_root = Path(local_path_root or DEFAULT_LOCAL_MODEL_ROOT) + try: + local_root.mkdir(parents=True, exist_ok=True) + # Check if the directory is writable. + with open(local_root / ".test", "w") as fp: + fp.write("test") + except PermissionError: + logger.warning( + "Failed to create local root directory at %s (Permission denied). " + "Reset local root to %s", + local_root, + FALLBACK_LOCAL_MODEL_ROOT, + ) + local_root = Path(FALLBACK_LOCAL_MODEL_ROOT) + local_root.mkdir(parents=True, exist_ok=True) + + return download_hf_model_from_s3(s3_path, local_root) + + +def download_model_from_hf( + model_name: str, local_path_root: Optional[str] = None +) -> str: + """Download model files from Hugging Face to the local path. + If the local path has permission issues, return the original model name, but warn the user. + + Args: + model_name: The model name to download. + local_path_root: The local path root to download to. If not provided, + will use the default path (/mnt/local_storage/cache or /tmp/cache + + Returns: + The local path where the files are downloaded. + """ + # If the model_name is already a local path, skip downloading + if model_name.startswith("/"): + return model_name + + local_model_path = Path(local_path_root or DEFAULT_LOCAL_MODEL_ROOT) / model_name + try: + local_model_path.mkdir(parents=True, exist_ok=True) + + # Check directory is writable by trying to list files (avoiding .test file creation) + if not os.access(local_model_path, os.W_OK): + raise PermissionError + except PermissionError: + logger.warning( + "Failed to create or write to the model directory at %s (Permission denied). " + "Please grant permission, or each worker may download the model, hitting rate limits.", + local_model_path, + ) + return model_name # Return the original model name + + snapshot_download(repo_id=model_name, local_dir=str(local_model_path)) + + return str(local_model_path) + + +def async_caller_empty_batch_handler(func) -> Callable: + """A decorator to handle the case where all rows are checkpointed. + When all rows are checkpointed, we will still get a batch + in pyarrow.Table format with empty rows. This is a bug and + is being tracked here: + https://github.com/anyscale/rayturbo/issues/1292 + + Args: + func: The function to wrap. + + Returns: + The wrapped function. + """ + + @wraps(func) + async def wrapper(self, batch): + if not isinstance(batch, pyarrow.lib.Table) or batch.num_rows > 0: + async for x in func(self, batch): + yield x + else: + yield {} + + return wrapper + + +def has_materialized(ds: Dataset) -> bool: + """Check if the dataset has been materialized. + TODO: This API should be moved to Ray Data. 
+ + Args: + ds: The dataset to check. + + Returns: + True if the dataset is materialized, False otherwise. + """ + return bool(ds.stats()) diff --git a/skythought/skythought_evals/batch/workload.py b/skythought/skythought_evals/batch/workload.py new file mode 100644 index 0000000..d4f35da --- /dev/null +++ b/skythought/skythought_evals/batch/workload.py @@ -0,0 +1,159 @@ +"""The workload.""" + +import math +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Dict, Optional, Tuple + +import yaml +from ray.data.dataset import Dataset + +from .logging import get_logger +from .tokenizer import ChatTemplateTokenizer + +logger = get_logger(__name__) + + +def load_config_from_path(config_path: str) -> Dict[str, Any]: + if isinstance(config_path, str): + config_path = Path(config_path) + if not config_path.exists(): + raise FileNotFoundError(f"Engine config file {config_path} not found.") + with open(config_path, "r") as filep: + config = yaml.safe_load(filep) + + assert isinstance(config, dict) + return config + + +@dataclass +class EvalWorkload: + # The ray.data.Dataset. If None, the Worklod must initialize the dataset + # in __post_init__(). + dataset: Optional[Dataset] + # Sampling a fraction of dataset for benchmarking and testing. If the value + # is greater than one, it means to take the first N rows from the dataset. + dataset_fraction: float = 1.0 + # Tokenizer class for the workload. + tokenizer_cls: Any = ChatTemplateTokenizer + + # Sampling parameters for the workload, such as max_tokens, temperature, etc. + # It can only be None when the workload is used for embedding. + sampling_params: Dict[str, Any] = field( + default_factory=lambda: {"max_tokens": 4096} + ) + # Pooling parameters for the workload, such as pooling_type, etc. + # It can only be None when the workload is used for auto-regressive generation. + pooling_params: Optional[Dict[str, Any]] = None + + need_tokenize: bool = True + # When specified, the tokenization will be async because we don't need to + # materialize an entire tokenized dataset to get the maximum tokens in prompt. + # With the default value of -1, the actual value will be set after tokenization. + max_tokens_in_prompt: int = -1 + + # Do we want to carry over input keys that are not in the output? + carryover_inputs: bool = True + + def validate(self): + if not ((self.sampling_params is None) ^ (self.pooling_params is None)): + raise ValueError( + "Either sampling_params or pooling_params must be specified." + ) + + def get_preprocessed_dataset( + self, + max_batch_size: int = 256, + repartition_by_batch_size: bool = False, + ) -> Tuple[Dataset, Optional[int]]: + """Load the dataset and process it. + + Args: + max_batch_size: The batch size. This determines the number of rows per + block. Note that if some rows have already processed (checkpointed), + the actual batch size may be smaller than this value. + repartition_by_batch_size: Whether to repartition the dataset by the + batch size for fault tolerance granularity. You should enable + this when the dataset is not from parquet and checkpointing is + disabled. + + Returns: + The processed dataset and the number of blocks. If checkpointing is + enabled, then the number of blocks is unknown. + """ + self.validate() + if self.dataset is None: + raise ValueError( + "dataset must be specified or initialized before calling " + "get_preprocessed_dataset()." 
+ ) + + self.max_batch_size = max_batch_size + + ds = self.dataset + if self.dataset_fraction < 1.0: + logger.info("Sampling %f dataset", self.dataset_fraction) + ds = ds.random_sample(self.dataset_fraction, seed=0) + elif self.dataset_fraction > 1.0: + n_rows = int(self.dataset_fraction) + logger.info("Taking the first %d rows from dataset", n_rows) + ds = ds.limit(n_rows) + + if repartition_by_batch_size: + num_requests = ds.count() + num_blocks = math.ceil(num_requests / max_batch_size) + ds = ds.repartition(num_blocks) + + logger.info("#Requests: %d (%d blocks)", num_requests, num_blocks) + else: + # When checkpointing is enabled, the number of blocks is unknown + # at this point. + num_blocks = None + + mapper_fn = ( + self.parse_row_with_carryover_input + if self.carryover_inputs + else self.parse_row + ) + return ds.map(mapper_fn), num_blocks + + def tokenizer_constructor_kwargs(self, model: str): + """Return the keyword arguments for tokenizer constructor. + + Args: + model: The model name. + + Returns: + The keyword arguments for tokenizer constructor. + """ + return {"model": model} + + def parse_row_with_carryover_input(self, row: dict[str, Any]) -> dict[str, Any]: + """Same as parse_row but carries over the input keys that are not in the output row. + + This is useful when we want to keep the input keys in the output. + This method assumes if user returns the same output keys as + input keys they have already copied input over and there is + no need to do it again for those keys. We will just copy the input_keys that + are not in the output row. + + Args: + row: The row to be parsed. + + Returns: + The parsed row. + """ + input_row_keys = set(row.keys()) + output_row = self.parse_row(row) + output_row_keys = set(output_row.keys()) + return { + **{k: row[k] for k in input_row_keys if k not in output_row_keys}, + **output_row, + } + + def parse_row(self, row: Dict[str, Any]) -> Dict[str, Any]: + """Parse each row in the dataset to make them compatible with + OpenAI chat API messages. Specifically, the output row should only + include a single key "messages" with type Dict[str, Union[str, List[Dict]]]. 
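+
+        For example, a row {"item": (3, messages)} produced by
+        ray.data.from_items() is parsed into {"messages": messages, "index": 3}.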
+ """ + return {"messages": row["item"][1], "index": row["item"][0]} diff --git a/skythought/skythought_evals/inference_and_check.py b/skythought/skythought_evals/inference_and_check.py index 9d779c0..524da95 100644 --- a/skythought/skythought_evals/inference_and_check.py +++ b/skythought/skythought_evals/inference_and_check.py @@ -1,11 +1,17 @@ import argparse import concurrent.futures +import copy import json +import math import os from concurrent.futures import ProcessPoolExecutor, as_completed from functools import partial import numpy as np +import ray +from batch import Pipeline, init_engine_from_config +from batch.env_config import EnvConfig +from batch.workload import EvalWorkload, load_config_from_path from openai import OpenAI from skythought_evals.models import ModelConfig, get_system_prompt_keys from skythought_evals.tasks import ( @@ -16,6 +22,7 @@ TaskHandler, ) from skythought_evals.util.common import set_seed +from skythought_evals.util.response import Response from tqdm import tqdm from vllm import LLM, SamplingParams @@ -55,6 +62,59 @@ def fetch_response_openai(llm, model_name, max_tokens, temp, prompt): return response +def fetch_responses_ray(conversations, max_tokens, temp, args): + config = load_config_from_path(args.ray_config) + config["model_id"] = args.model + engine_cfg = init_engine_from_config(config) + ds = ray.data.from_items([(idx, conv) for idx, conv in enumerate(conversations)]) + num_replicas = config["env_config"].get("num_replicas", 1) + if ds.count() < config["env_config"].get("batch_size", 1): + config["env_config"]["batch_size"] = math.ceil(ds.count() / num_replicas) + if num_replicas > 1 and num_replicas > ds.num_blocks(): + ds = ds.repartition(num_partitions=num_replicas) + workload = EvalWorkload( + dataset=ds, + sampling_params={"n": args.n, "max_tokens": max_tokens, "temperature": temp}, + ) + pipeline = Pipeline( + engine_cfg, + env_config=EnvConfig(**config["env_config"]), + ) + ds = pipeline(workload) + responses = ds.materialize() + return responses + + +def inference(llm, conversations, max_tokens, temp, args): + if args.use_ray: + responses = fetch_responses_ray(conversations, max_tokens, temp, args) + responses = [ + Response.from_ray_response(response) for response in responses.iter_rows() + ] + # TODO/NOTE: This deepcopy is needed to avoid a SIGSEV error related to object cleanup with the ray object store and + # the later use of ProcessPoolExecutor - see here: https://github.com/NovaSky-AI/SkyThought/pull/63#discussion_r1941899714 + # revisit the underlying issue and remove the deepcopy if possible + responses = copy.deepcopy(responses) + responses = sorted(responses, key=lambda x: x.index) + elif args.model.startswith("openai"): + fetch_partial = partial( + fetch_response_openai, llm, args.model, max_tokens, temp + ) + + with concurrent.futures.ThreadPoolExecutor(max_workers=16) as e: + responses = list(e.map(fetch_partial, conversations)) + + responses = [Response.from_openai_response(response) for response in responses] + else: + sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temp) + responses = llm.chat( + messages=conversations, sampling_params=sampling_params, use_tqdm=True + ) + responses = [Response.from_vllm_response(response) for response in responses] + + return responses + + def perform_inference_and_check( handler: TaskHandler, temperatures, @@ -64,6 +124,7 @@ def perform_inference_and_check( model_config, args, ): + assert args.n == 1, "Check does not support multiple samples" results = 
handler.load_existing_results(result_file) print(f"Loaded {len(results)} existing results.") train_data = handler.load_and_filter_dataset( @@ -79,48 +140,28 @@ def perform_inference_and_check( remaining_data, model_config.system_prompt, model_config.user_template ) for temp in temperatures: - if args.model.startswith("openai"): - fetch_partial = partial( - fetch_response_openai, llm, args.model, max_tokens, temp - ) + if len(conversations) == 0: + print("No more data to process") + continue - with concurrent.futures.ThreadPoolExecutor(max_workers=16) as e: - responses = list(e.map(fetch_partial, conversations)) - - else: - sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temp) - responses = llm.chat( - messages=conversations, sampling_params=sampling_params, use_tqdm=True - ) + responses = inference(llm, conversations, max_tokens, temp, args) total_correct = 0 total_finish = 0 with ProcessPoolExecutor(max_workers=32) as executor: - # future_to_task = { - # executor.submit(handler.update_results, remaining_data[idx], response): idx - # for idx, response in enumerate(responses) - # } future_to_task = {} token_usages = {} for idx, response in enumerate(responses): - if args.model.startswith("openai"): - response_str = response.choices[0].message.content.strip() - else: - response_str = response.outputs[0].text.strip() + response_str = response.response.strip() future_to_task[ executor.submit( handler.update_results, remaining_data[idx], response_str ) ] = idx - # print(f"Request output: {response}") - - if args.model.startswith("openai"): - token_usages[idx] = response.usage - else: - token_usages[idx] = { - "completion_tokens": len(response.outputs[0].token_ids), - "prompt_tokens": len(response.prompt_token_ids), - } + token_usages[idx] = { + "completion_tokens": response.num_completion_tokens, + "prompt_tokens": response.num_input_tokens, + } for future in tqdm( as_completed(future_to_task), @@ -145,14 +186,7 @@ def perform_inference_and_check( results[problem_key]["responses"][str(temp)] = response_entry - if args.model.startswith("openai"): - results[problem_key]["token_usages"][str(temp)] = { - "completion_tokens": token_usages[idx].completion_tokens, - "prompt_tokens": token_usages[idx].prompt_tokens, - } - else: - # TODO: vLLM model, can it do the same thing - results[problem_key]["token_usages"][str(temp)] = token_usages[idx] + results[problem_key]["token_usages"][str(temp)] = token_usages[idx] print(f"Final acc: {total_correct}/{total_finish}") acc = round(total_correct / total_finish, 4) if total_finish > 0 else 0 @@ -323,21 +357,10 @@ def perform_inference_and_save( ) for temp in temperatures: - if args.model.startswith("openai"): - fetch_partial = partial( - fetch_response_openai, llm, args.model, max_tokens, temp - ) - - with concurrent.futures.ThreadPoolExecutor(max_workers=16) as e: - responses = list(e.map(fetch_partial, conversations)) - - else: - sampling_params = SamplingParams( - n=args.n, max_tokens=max_tokens, temperature=temp - ) - responses = llm.chat( - messages=conversations, sampling_params=sampling_params, use_tqdm=True - ) + if len(conversations) == 0: + print("No more data to process") + continue + responses = inference(llm, conversations, max_tokens, temp, args) completion_tokens = [] prompt_tokens = [] @@ -346,28 +369,36 @@ def perform_inference_and_save( token_usages = [] completion_token = 0 for sample_idx in range(args.n): + if args.model.startswith("openai"): + content = response.response.strip() + else: + content = 
response.response[sample_idx].strip() response_entry = { - "content": ( - response.choices[0].message.content.strip() - if args.model.startswith("openai") - else response.outputs[sample_idx].text.strip() - ), + "content": content, "correctness": None, "reason": None, } response_entries.append(response_entry) - if not args.model.startswith("openai"): + if args.model.startswith("openai"): + token_usages.append( + { + "completion_tokens": response.num_completion_tokens, + "prompt_tokens": response.num_input_tokens, + } + ) + else: token_usages.append( { - "completion_tokens": len( - response.outputs[sample_idx].token_ids - ), - "prompt_tokens": len(response.prompt_token_ids), + "completion_tokens": response.num_completion_tokens[ + sample_idx + ], + "prompt_tokens": response.num_input_tokens, } ) - completion_token += len(response.outputs[sample_idx].token_ids) + completion_token += response.num_completion_tokens[sample_idx] + completion_token /= args.n - prompt_token = len(response.prompt_token_ids) + prompt_token = response.num_input_tokens prompt_tokens.append(prompt_token) completion_tokens.append(completion_token) @@ -385,13 +416,7 @@ def perform_inference_and_save( results[problem_key]["responses"][str(temp)] = response_entries - if args.model.startswith("openai"): - results[problem_key]["token_usages"][str(temp)] = { - "completion_tokens": response.usage.completion_tokens, - "prompt_tokens": response.usage.prompt_tokens, - } - else: - results[problem_key]["token_usages"][str(temp)] = token_usages + results[problem_key]["token_usages"][str(temp)] = token_usages # Token usage summary put into another subdirectory result_dir, result_name = os.path.split(result_file) @@ -511,10 +536,23 @@ def main(): "--n", type=int, default=1, help="Number of samples generated per problem." ) parser.add_argument("--seed", type=int, default=41, help="Random seed.") + parser.add_argument( + "--use_ray", action="store_true", help="Use ray for scaling inference." + ) + parser.add_argument( + "--ray_config", + type=str, + default="ray_configs/ray_config.yaml", + help="Ray configuration file if using ray for scaling inference.", + ) args = parser.parse_args() set_seed(args.seed) + # use os to enable hf_transfer for model download + if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", None) not in ["1", "True"]: + os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" + if args.task not in TASK_NAMES_TO_YAML: raise ValueError( f"Task {args.task} not found. 
Should be one of {TASK_NAMES_TO_YAML.keys()}" @@ -576,32 +614,23 @@ def main(): result_file = converted_file perform_check(handler, temperatures, result_file, args) return - elif args.inference: - llm = ( - OpenAI() - if args.model.startswith("openai") - else LLM(model=args.model, tensor_parallel_size=args.tp) - ) - perform_inference_and_save( - handler, temperatures, max_tokens, result_file, llm, model_config, args - ) - return - - llm = ( - OpenAI() - if args.model.startswith("openai") - else LLM(model=args.model, tensor_parallel_size=args.tp) - ) - - perform_inference_and_check( - handler, - temperatures, - max_tokens, - result_file, - llm, - model_config, - args, - ) + else: + if args.use_ray: + llm = None + else: + llm = ( + OpenAI() + if args.model.startswith("openai") + else LLM(model=args.model, tensor_parallel_size=args.tp) + ) + if args.inference: + perform_inference_and_save( + handler, temperatures, max_tokens, result_file, llm, model_config, args + ) + else: + perform_inference_and_check( + handler, temperatures, max_tokens, result_file, llm, model_config, args + ) if __name__ == "__main__": diff --git a/skythought/skythought_evals/ray_configs/ray_config.yaml b/skythought/skythought_evals/ray_configs/ray_config.yaml new file mode 100644 index 0000000..e56a183 --- /dev/null +++ b/skythought/skythought_evals/ray_configs/ray_config.yaml @@ -0,0 +1,23 @@ +llm_engine: vllm # currently only vllm supported +accelerator_type: H100 # accelerator name as specified here: https://docs.ray.io/en/master/ray-core/accelerator-types.html#accelerator-types +engine_kwargs: # vllm engine kwargs + tensor_parallel_size: 1 + gpu_memory_utilization: 0.9 + # other optional vllm engine kwargs to tune performance! + # pipeline_parallel_size: 1 + # max_num_seqs: 448 + # use_v2_block_manager: True + # enable_prefix_caching: False + # preemption_mode: "recompute" + # block_size: 16 + # kv_cache_dtype: "auto" + # enforce_eager: False + # enable_chunked_prefill: True + # max_num_batched_tokens: 8192 + # max_seq_len_to_capture: 32768 +runtime_env: + env_vars: + VLLM_ATTENTION_BACKEND: "FLASH_ATTN" +env_config: + num_replicas: 8 # number of vllm replicas + batch_size: 128 # ray pipeline internal batch size (used for map_batches call internally). Should usually be set to a value in [64, 128, 256] for best performance. diff --git a/skythought/skythought_evals/requirements.txt b/skythought/skythought_evals/requirements.txt index d8d5712..928ee18 100644 --- a/skythought/skythought_evals/requirements.txt +++ b/skythought/skythought_evals/requirements.txt @@ -5,4 +5,4 @@ scipy datasets latex2sympy2 pydantic -setuptools \ No newline at end of file +setuptools diff --git a/skythought/skythought_evals/util/response.py b/skythought/skythought_evals/util/response.py new file mode 100644 index 0000000..f13b4b4 --- /dev/null +++ b/skythought/skythought_evals/util/response.py @@ -0,0 +1,81 @@ +from dataclasses import dataclass +from typing import Optional, Union + + +@dataclass +class Response: + response: Union[str, list[str]] + num_completion_tokens: Union[int, list[int]] + num_input_tokens: int + index: Optional[int] = None + + @classmethod + def from_ray_response(cls, response) -> "Response": + """ + Factory method to create a Response instance from a rayllm response. 
+
+        Args:
+            response: Ray response object containing generated text and token information
+
+        Returns:
+            Response: New instance initialized with Ray response data
+        """
+        if isinstance(response["generated_text"], list):
+            # n > 1 samples
+            num_completion_tokens = [
+                int(response["num_generated_tokens"][i])
+                for i in range(len(response["num_generated_tokens"]))
+            ]
+        else:
+            num_completion_tokens = int(response["num_generated_tokens"])
+        return cls(
+            response=response["generated_text"],
+            num_completion_tokens=num_completion_tokens,
+            num_input_tokens=int(response["num_input_tokens"]),
+            index=response["index"],
+        )
+
+    @classmethod
+    def from_openai_response(cls, response) -> "Response":
+        """
+        Factory method to create a Response instance from an OpenAI response.
+
+        Args:
+            response: OpenAI response object containing message content and token information
+
+        Returns:
+            Response: New instance initialized with OpenAI response data
+        """
+        # TODO: allow for multiple samples
+        return cls(
+            response=response.choices[0].message.content,
+            num_completion_tokens=response.usage.completion_tokens,
+            num_input_tokens=response.usage.prompt_tokens,
+        )
+
+    @classmethod
+    def from_vllm_response(cls, response) -> "Response":
+        """
+        Factory method to create a Response instance from a vLLM response.
+
+        Args:
+            response: vLLM response object containing output text and token information
+
+        Returns:
+            Response: New instance initialized with vLLM response data
+        """
+        response_text = (
+            [response.outputs[i].text for i in range(len(response.outputs))]
+            if len(response.outputs) > 1
+            else response.outputs[0].text
+        )
+        num_completion_tokens = (
+            [len(output.token_ids) for output in response.outputs]
+            if len(response.outputs) > 1
+            else len(response.outputs[0].token_ids)
+        )
+        return cls(
+            response=response_text,
+            num_completion_tokens=num_completion_tokens,
+            num_input_tokens=len(response.prompt_token_ids),
+        )
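
Usage note (reviewer sketch, not part of the patch): the snippet below illustrates how downstream code is expected to consume the unified Response wrapper, regardless of whether it was built by from_vllm_response, from_openai_response, or from_ray_response. The concrete strings and token counts are invented for illustration, and the import path assumes the skythought_evals package layout introduced in this PR is importable.

# Illustrative sketch only; values are made up, and the import assumes the
# skythought_evals/util/response.py module added above is installed/importable.
from skythought_evals.util.response import Response

# n == 1: `response` is a single string and `num_completion_tokens` is an int.
single = Response(response="The answer is 42.", num_completion_tokens=7, num_input_tokens=128)

# n > 1: both fields become per-sample lists; `num_input_tokens` stays a single int.
multi = Response(
    response=["First attempt.", "Second attempt."],
    num_completion_tokens=[5, 6],
    num_input_tokens=128,
)


def avg_completion_tokens(resp: Response) -> float:
    # Mirrors the per-problem averaging in perform_inference_and_save
    # (completion_token /= args.n).
    if isinstance(resp.num_completion_tokens, list):
        return sum(resp.num_completion_tokens) / len(resp.num_completion_tokens)
    return float(resp.num_completion_tokens)


print(avg_completion_tokens(single))  # 7.0
print(avg_completion_tokens(multi))   # 5.5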