Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[evals] Add support for scaling evals and inference with ray #63

Merged
merged 33 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from 30 commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
bf56116
add rayllm batch path
erictang000 Jan 28, 2025
fe0d518
fix typo
erictang000 Jan 28, 2025
94e9009
temp make mmlu pro smaller
erictang000 Jan 30, 2025
88ad60f
update vllm version, rayllm config, add repartition
erictang000 Jan 31, 2025
efc6a01
Add submodule repo
erictang000 Jan 31, 2025
1fadec9
move evalworkload code to outside module
erictang000 Jan 31, 2025
cf6a9fd
Update submodule to latest commit
erictang000 Jan 31, 2025
ed759e7
remove [:4000] for mmlupro
erictang000 Jan 31, 2025
6429201
updates to inference_and_save path
erictang000 Feb 1, 2025
7d2f39e
fix small issues
erictang000 Feb 1, 2025
91a7d5a
disable n > 1 for inference and save rayllm path
erictang000 Feb 1, 2025
ea7694a
Merge branch 'rayllm' of https://github.com/erictang000/SkyThought in…
erictang000 Feb 1, 2025
f69406a
inference_and_save works for n = 1 use_rayllm
erictang000 Feb 1, 2025
461a43c
Remove submodule
erictang000 Feb 1, 2025
f47dcfa
remove submodule and rename to pipeline
erictang000 Feb 1, 2025
7a0ee7b
remove unnecessary model_id
erictang000 Feb 1, 2025
ec1b192
remove .gitmodules
erictang000 Feb 1, 2025
224bbab
rename main to pipeline
erictang000 Feb 1, 2025
d419fc3
add support for n > 1
erictang000 Feb 2, 2025
a92dfe4
remove old code
erictang000 Feb 2, 2025
bd27b8d
fix unflatten logic for n > 1
erictang000 Feb 2, 2025
972e6e5
merge
erictang000 Feb 3, 2025
00bfc5b
finish merge stuff
erictang000 Feb 3, 2025
dd615ae
split small datasets
erictang000 Feb 3, 2025
e02d0d9
address some comments (add response object and add separate inference…
erictang000 Feb 4, 2025
c37cf6d
small comment use_ray
erictang000 Feb 4, 2025
4b15c25
resolve some more comments
erictang000 Feb 4, 2025
9ef28f9
reduce workload
erictang000 Feb 4, 2025
50e79a0
changes
erictang000 Feb 4, 2025
cc7dd8b
fix ProcessPoolExecutor + response ray sigsev bug
erictang000 Feb 4, 2025
1b6843e
Merge branch 'main' of https://github.com/erictang000/SkyThought into…
erictang000 Feb 4, 2025
2045247
add comment
erictang000 Feb 4, 2025
6615a78
Merge branch 'main' of https://github.com/NovaSky-AI/SkyThought into …
SumanthRH Feb 6, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ cython_debug/
# Vim
*.swp

.json
*.json
token_usage/

run_all.sh
Empty file modified format.sh
100644 → 100755
Empty file.
13 changes: 13 additions & 0 deletions skythought/skythought_evals/batch/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Public API for the batch inference package."""

from .engines import init_engine_from_config
from .pipeline import Pipeline
from .workload import (
    EvalWorkload,
)

# Explicitly declare the public API (the initial dead `__all__ = []`
# assignment was removed: it was immediately shadowed by this one).
__all__ = [
    "Pipeline",
    "init_engine_from_config",
    "EvalWorkload",
]
10 changes: 10 additions & 0 deletions skythought/skythought_evals/batch/engines/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""LLM Engines."""

__all__ = []

from .initializer import EngineInitializerBase, init_engine_from_config

__all__ = [
"EngineInitializerBase",
"init_engine_from_config",
]
22 changes: 22 additions & 0 deletions skythought/skythought_evals/batch/engines/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Engine base."""

from typing import Any, AsyncGenerator, Dict

import numpy as np


class EngineBase:
    """Common interface implemented by all batch-inference engines.

    Subclasses override ``__call__`` to consume a Ray Data batch and
    asynchronously yield output rows.
    """

    async def __call__(
        self, batch: Dict[str, np.ndarray]
    ) -> AsyncGenerator[Dict[str, Any], None]:
        """Process one Ray Data batch asynchronously.

        Args:
            batch: Mapping of column name to a numpy array of column values.

        Yields:
            Output rows produced by the engine.
        """
        # Abstract: concrete engines must provide the implementation.
        raise NotImplementedError
264 changes: 264 additions & 0 deletions skythought/skythought_evals/batch/engines/initializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,264 @@
"""Engine initializers.
Note that this file should not import any engine-dependent modules, such as
vLLM, because the engine initializer is used in the driver node which may
not have GPUs.
"""

import os
from pathlib import Path
from typing import Any, Dict, Optional, Union

import yaml

from ..utils import (
download_model_from_hf,
update_dict_recursive,
)
from ..workload import EvalWorkload
from .base import EngineBase


class EngineInitializerBase:
    """Base class for engine initializers.

    An engine initializer holds the configuration needed to construct an LLM
    engine on a worker node. It deliberately avoids importing engine-specific
    modules so it can run on a (possibly GPU-less) driver node.

    Args:
        model_id: The model id.
        accelerator_type: The accelerator (GPU) type.
        engine_kwargs: The engine specific configurations.
        lora_adapter: Optional path for dynamically loading a LoRA adapter.
        ray_env_vars: Environment variables to set in the Ray runtime
            environment. Defaults to an empty dict when not given.
    """

    # Whether engine replicas must be scheduled via a Ray placement group.
    use_ray_placement_group: bool = False

    def __init__(
        self,
        model_id: str,
        accelerator_type: str,
        engine_kwargs: Dict[str, Any],
        lora_adapter: Optional[str] = None,
        # Annotation fixed: the default is None, so the type is Optional.
        ray_env_vars: Optional[Dict[str, Any]] = None,
    ):
        self._model = model_id
        self._accelerator_type = accelerator_type
        self._ray_env_vars = ray_env_vars or {}
        self.lora_adapter = lora_adapter
        self.engine_kwargs = engine_kwargs

    @property
    def model(self) -> str:
        """The model id."""
        return self._model

    @property
    def accelerator_type(self) -> str:
        """The accelerator (GPU) type."""
        return self._accelerator_type

    @property
    def ray_env_vars(self) -> Dict[str, str]:
        """Environment variables for the Ray runtime environment."""
        return self._ray_env_vars

    @property
    def num_gpus(self) -> int:
        """The number of GPUs used per engine."""
        raise NotImplementedError

    @property
    def max_model_len(self) -> Optional[int]:
        """The maximum model length set by the engine, or None if unset."""
        return None

    def get_engine_cls(self) -> EngineBase:
        """Get the engine class.

        Returns:
            The engine class.
        """
        raise NotImplementedError

    def get_engine_constructor_args(self, workload: EvalWorkload) -> Dict[str, Any]:
        """Get the engine constructor arguments.

        Args:
            workload: The workload that the engine will process.

        Returns:
            The engine constructor keyword arguments.
        """
        raise NotImplementedError


class vLLMEngineInitializer(EngineInitializerBase):
    """Engine initializer for vLLM.

    Applies vLLM-specific default engine kwargs and environment variables at
    construction time, and derives the engine constructor arguments from the
    workload (generation vs. embedding).
    """

    use_ray_placement_group: bool = False

    def __init__(
        self,
        model_id: str,
        accelerator_type: str,
        engine_kwargs: Dict[str, Any],
        lora_adapter: Optional[str] = None,
        # Annotation fixed: the default is None, so the type is Optional.
        ray_env_vars: Optional[Dict[str, Any]] = None,
    ):
        super().__init__(
            model_id, accelerator_type, engine_kwargs, lora_adapter, ray_env_vars
        )

        # Override vLLM default configs. Note that this is only effective
        # when the config is not set by users.
        self.engine_kwargs.setdefault("gpu_memory_utilization", 0.95)
        self.engine_kwargs.setdefault("use_v2_block_manager", True)
        self.engine_kwargs.setdefault("enable_prefix_caching", False)
        self.engine_kwargs.setdefault("enforce_eager", False)
        self.engine_kwargs.setdefault("pipeline_parallel_size", 1)
        self.engine_kwargs.setdefault("max_num_seqs", 256)
        self.engine_kwargs.setdefault("tensor_parallel_size", 1)
        self.engine_kwargs.setdefault("max_logprobs", 0)
        self.engine_kwargs.setdefault("distributed_executor_backend", "mp")

        # Set engine environment variables.
        # NOTE: these are written into os.environ of the current process,
        # in addition to being passed along as Ray env vars.
        self._ray_env_vars.setdefault("VLLM_ATTENTION_BACKEND", "FLASH_ATTN")
        self._ray_env_vars.setdefault("ENABLE_ANYSCALE_PREFIX_OPTIMIZATIONS", "0")
        # FIXME: This should already be deprecated and can be removed.
        self._ray_env_vars.setdefault("VLLM_DISABLE_LOGPROBS", "1")
        for key, value in self._ray_env_vars.items():
            os.environ[key] = str(value)

    def get_engine_cls(self):
        """Return the vLLM engine class (imported lazily so the driver node
        does not need vLLM installed)."""
        from .vllm_engine import AsyncLLMPredictor

        return AsyncLLMPredictor

    @property
    def num_gpus(self) -> int:
        """Number of GPUs per engine: tensor parallel x pipeline parallel."""
        assert "tensor_parallel_size" in self.engine_kwargs
        assert "pipeline_parallel_size" in self.engine_kwargs
        tp_size = self.engine_kwargs["tensor_parallel_size"]
        pp_size = self.engine_kwargs["pipeline_parallel_size"]
        return tp_size * pp_size

    @property
    def max_model_len(self) -> Optional[int]:
        """The maximum model length set by the engine."""
        return self.engine_kwargs.get("max_model_len", None)

    def get_engine_constructor_args(self, workload: EvalWorkload):
        """Build the keyword arguments used to construct the vLLM engine.

        NOTE: this may mutate ``self.engine_kwargs`` (it can set
        ``max_model_len`` and rewrap ``override_pooler_config``).

        Args:
            workload: The workload that the engine will process.

        Returns:
            The engine constructor keyword arguments.
        """
        from vllm import PoolingParams, SamplingParams
        from vllm.config import PoolerConfig

        constructor_kwargs = {
            "model": self.model,
            "lora_adapter": self.lora_adapter,
        }

        if sampling_params := workload.sampling_params:
            # Sampling params is given: Auto-regressive generation.
            # In this case, we need to set max_tokens and max_model_len.

            max_tokens = sampling_params.get("max_tokens", None)
            if max_tokens is None:
                raise ValueError("max_tokens is required for vLLM engine.")

            vllm_sampling_params = SamplingParams(**workload.sampling_params)
            vllm_sampling_params.max_tokens = max_tokens
            # Tokens are detokenized downstream, not by the engine.
            vllm_sampling_params.detokenize = False
            constructor_kwargs["params"] = vllm_sampling_params

            if (
                "max_model_len" not in self.engine_kwargs
                and workload.max_tokens_in_prompt < 0
            ):
                raise ValueError(
                    "Neither max_tokens_in_prompt nor max_model_len is set. If you "
                    "intend to let the pipeline infer max_tokens_in_prompt but got this error, "
                    "it is either because the workload has not been tokenized, or the "
                    "workload bypass the tokenizer but does not set max_tokens_in_prompt by itself."
                )

            # Use max_tokens_in_prompt + max_tokens as the max_model_len. max_tokens_in_prompt
            # is either inferred by materializing tokenized dataset, set by the workload, or
            # set by the engine.
            self.engine_kwargs["max_model_len"] = (
                workload.max_tokens_in_prompt + max_tokens
            )
        else:
            # Sampling params is not given: Embedding workload.
            # In this case, we need to set pooling_params and task.

            if workload.pooling_params is None:
                raise ValueError(
                    "pooling_params is required for vLLM engine for embedding workload."
                )
            constructor_kwargs["params"] = PoolingParams(**workload.pooling_params)
            constructor_kwargs["task"] = "embed"

            # Construct PoolerConfig if override_pooler_config is specified.
            if pooler_config := self.engine_kwargs.get("override_pooler_config", None):
                self.engine_kwargs["override_pooler_config"] = PoolerConfig(
                    **pooler_config
                )

        constructor_kwargs.update(self.engine_kwargs)
        return constructor_kwargs


def init_engine_from_config(
    config: Union[Dict[str, Any], str], override: Optional[Dict[str, Any]] = None
) -> EngineInitializerBase:
    """Initialize an engine initializer from a config file or a config dict.

    Args:
        config: A config file (in YAML) or a config dict. It should include
            the following keys: "llm_engine", the backend engine to use;
            "model_id", the model to use; "accelerator_type", the GPU type;
            "engine_kwargs", the engine specific configurations (optional,
            defaults to an empty dict).
        override: Override values in the loaded config.

    Returns:
        An engine initializer.

    Raises:
        FileNotFoundError: If ``config`` is a path that does not exist.
        TypeError: If the loaded config is not a dict or ``model_id`` is
            not a string.
        KeyError: If a required key is missing from the config.
        ValueError: If the configured engine is unknown.
    """
    if isinstance(config, str):
        config_path = Path(config)
        if not config_path.exists():
            raise FileNotFoundError(f"Engine config file {config} not found.")
        with open(config_path, "r") as filep:
            config = yaml.safe_load(filep)

    # Validate with real exceptions instead of `assert`, which is stripped
    # under `python -O`.
    if not isinstance(config, dict):
        raise TypeError(f"Engine config must be a dict, got {type(config)}.")

    # Override configs
    if override is not None:
        update_dict_recursive(config, override)

    # Sanity check: required keys must exist before we use them below.
    for key in ("llm_engine", "model_id", "accelerator_type"):
        if key not in config:
            raise KeyError(f"Required {key} not found in config.")
    if not isinstance(config["model_id"], str):
        raise TypeError("model_id must be a string.")
    if "engine_kwargs" not in config:
        config["engine_kwargs"] = {}

    # Ray runtime environments.
    runtime_env: Dict[str, Any] = config.get("runtime_env", {})
    ray_env_vars: Dict[str, Any] = runtime_env.get("env_vars", {})

    # Download model and save to local path in advance, in case
    # too many workers download the model in parallel and hit the
    # huggingface rate limit. The flag is popped so it is not forwarded
    # to the engine workers.
    if ray_env_vars.pop("PREDOWNLOAD_MODEL_FROM_HF", "0") == "1":
        config["model_id"] = download_model_from_hf(
            config["model_id"], "/mnt/cluster_storage"
        )

    # Do not download the LoRA adapter here because it is not used in the
    # driver node.
    lora_adapter = None
    if "lora_config" in config:
        lora_adapter = config["lora_config"].get("dynamic_lora_loading_path", None)

    name = config["llm_engine"]
    if name == "vllm":
        return vLLMEngineInitializer(
            model_id=config["model_id"],
            accelerator_type=config["accelerator_type"],
            engine_kwargs=config["engine_kwargs"],
            lora_adapter=lora_adapter,
            ray_env_vars=ray_env_vars,
        )

    raise ValueError(f"Unknown engine: {name}")
Loading