reduce workload
erictang000 committed Feb 4, 2025
1 parent 4b15c25 commit 9ef28f9
Showing 1 changed file with 8 additions and 158 deletions.
166 changes: 8 additions & 158 deletions skythought/skythought_evals/batch/workload.py
@@ -1,67 +1,19 @@
"""The workload."""

import math
import os
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

import ray
from ray.data.dataset import Dataset

try:
from ray.anyscale.data.checkpoint.interfaces import (
CheckpointBackend,
CheckpointConfig,
)
except ImportError:
CheckpointConfig = None
CheckpointBackend = None

from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import yaml
from ray.data.dataset import Dataset

from .logging import get_logger
from .tokenizer import ChatTemplateTokenizer
from .utils import has_materialized

logger = get_logger(__name__)


def read_parquet(path: str, max_batch_size: int) -> Dataset:
"""Read the parquet file and return the dataset.
Unlike ray.data.read_parquet, this function controls
the concurrency and the number of blocks, for two reasons:
- We have to set concurrency explicitly; otherwise
read_parquet asks for all CPU cores by default,
leaving insufficient resources for other stages.
- The number of blocks determines the fault-tolerance granularity.
Args:
path: The path to the parquet file.
max_batch_size: The maximum batch size per block.
Returns:
The dataset.
"""
ds = ray.data.read_parquet(
path,
concurrency=2,
)
# The .count() is fast only if the dataset is read from
# parquet and we are using RayTurbo 2.39+, in which case Ray Data
# only reads the metadata and does not materialize the dataset.
num_rows = ds.count()
num_blocks = math.ceil(num_rows / max_batch_size)

return ray.data.read_parquet(
path,
concurrency=2,
override_num_blocks=num_blocks,
)
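
As a minimal usage sketch of the helper above (the path and Ray setup are hypothetical, not part of this commit):

import ray

ray.init(ignore_reinit_error=True)
# Each block holds at most max_batch_size rows, so a failed task
# re-executes at most one batch worth of work.
ds = read_parquet("s3://my-bucket/prompts.parquet", max_batch_size=256)
print(ds.count())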


def load_config_from_path(config_path: str) -> Dict[str, Any]:
if isinstance(config_path, str):
config_path = Path(config_path)
@@ -75,28 +27,26 @@ def load_config_from_path(config_path: str) -> Dict[str, Any]:


@dataclass
class WorkloadBase:
"""The base class for a workload."""

class EvalWorkload:
# The ray.data.Dataset. If None, the workload must initialize the dataset
# in __post_init__().
dataset: Optional[Dataset]
# Fraction of the dataset to sample, for benchmarking and testing. If the
# value is greater than one, it is treated as the number of rows to take
# from the head of the dataset.
dataset_fraction: float
dataset_fraction: float = 1.0
# Tokenizer class for the workload.
tokenizer_cls: Any
tokenizer_cls: Any = ChatTemplateTokenizer

# Sampling parameters for the workload, such as max_tokens, temperature, etc.
# It can only be None when the workload is used for embedding.
sampling_params: Optional[Dict[str, Any]] = None
sampling_params: Dict[str, Any] = field(
default_factory=lambda: {"max_tokens": 4096}
)
# Pooling parameters for the workload, such as pooling_type, etc.
# It can only be None when the workload is used for auto-regressive generation.
pooling_params: Optional[Dict[str, Any]] = None

need_tokenize: bool = True
# The column name to be used as the checkpoint ID.
ckpt_col_name: Optional[str] = None
# When specified, tokenization will be async, because we don't need to
# materialize an entire tokenized dataset to get the maximum number of
# tokens in a prompt. With the default value of -1, the actual value will
# be set after tokenization.
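
A hedged construction sketch for the fields above (the in-memory rows are hypothetical stand-ins for a real dataset):

import ray

rows = [{"problem": f"question {i}"} for i in range(100)]
workload = EvalWorkload(
    dataset=ray.data.from_items(rows),
    dataset_fraction=8,  # greater than one: take the first 8 rows
)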
@@ -114,7 +64,6 @@ def validate(self):
def get_preprocessed_dataset(
self,
max_batch_size: int = 256,
ckpt_path: Optional[str] = None,
repartition_by_batch_size: bool = False,
) -> Tuple[Dataset, Optional[int]]:
"""Load the dataset and process it.
@@ -123,7 +72,6 @@ def get_preprocessed_dataset(
max_batch_size: The batch size. This determines the number of rows per
block. Note that if some rows have already been processed (checkpointed),
the actual batch size may be smaller than this value.
ckpt_path: The path to the checkpoint directory.
repartition_by_batch_size: Whether to repartition the dataset by the
batch size for fault tolerance granularity. You should enable
this when the dataset is not from parquet and checkpointing is
@@ -142,52 +90,7 @@ def get_preprocessed_dataset(

self.max_batch_size = max_batch_size

# Setup checkpointing.
enable_ckpt = False
if self.ckpt_col_name is not None:
if ckpt_path is not None:
if CheckpointConfig is None:
logger.warning(
"Checkpoint is not available. Please install Ray Turbo 2.39+"
)
else:
enable_ckpt = True

if has_materialized(self.dataset):
curr_ds_stats = self.dataset.stats()
raise RuntimeError(
"The checkpointing is specified in this workload, but "
"the dataset has already been materialized so checkpointing "
"cannot be applied. Please either disable checkpointing "
"or avoid to materialize the dataset before running this "
f"workload. Current dataset stats: {curr_ds_stats}"
)

# Determine the checkpoint backend.
ckpt_path = os.path.expandvars(ckpt_path)
if ckpt_path.startswith("s3://"):
ckpt_backend = CheckpointBackend.S3
else:
ckpt_backend = CheckpointBackend.DISK

ckpt_config = CheckpointConfig(
enabled=True,
backend=ckpt_backend,
id_col=self.ckpt_col_name,
output_path=ckpt_path,
)
ray.data.DataContext.get_current().checkpoint_config = ckpt_config
else:
logger.warning(
"This workload specifies checkpoint column name but checkpoint "
"path is not provided, so checkpointing will be disabled."
)

ds = self.dataset
if enable_ckpt and self.dataset_fraction != 1.0:
raise ValueError(
"Cannot process dataset fraction when checkpoint is enabled"
)
if self.dataset_fraction < 1.0:
logger.info("Sampling %f dataset", self.dataset_fraction)
ds = ds.random_sample(self.dataset_fraction, seed=0)
@@ -197,11 +100,6 @@ def get_preprocessed_dataset(
ds = ds.limit(n_rows)

if repartition_by_batch_size:
if enable_ckpt:
raise ValueError(
"Cannot repartition the dataset by batch size when "
"checkpointing is enabled."
)
num_requests = ds.count()
num_blocks = math.ceil(num_requests / max_batch_size)
ds = ds.repartition(num_blocks)
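
Continuing the sketch above: with 100 rows and max_batch_size=32, the repartition step yields ceil(100 / 32) = 4 blocks. A hypothetical call:

ds, max_tokens_in_prompt = workload.get_preprocessed_dataset(
    max_batch_size=32,
    repartition_by_batch_size=True,  # dataset is not from parquet
)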
@@ -242,54 +140,6 @@ def parse_row_with_carryover_input(self, row: dict[str, Any]) -> dict[str, Any]:
**output_row,
}

def tokenizer_constructor_kwargs(self, model: str):
"""Return the keyword arguments for tokenizer constructor.
Args:
model: The model name.
Returns:
The keyword arguments for tokenizer constructor.
"""
return {"model": model}

def parse_row(self, row: dict[str, Any]) -> dict[str, Any]:
"""Parse a row in the dataset.
Args:
row: The row in the dataset.
Returns:
The parsed row.
"""
return row

def postproc_after_tokenize(self, ds: Dataset) -> Dataset:
"""Post-process the dataset after tokenization.
Args:
ds: The tokenized dataset.
Returns:
The post-processed dataset.
"""
return ds


@dataclass
class ChatWorkloadBase(WorkloadBase):
"""The base class for a chat workload."""

tokenizer_cls: Any = ChatTemplateTokenizer
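
An illustrative subclass of the base above (hypothetical, not from this repository), overriding the parse_row hook to emit OpenAI-style messages:

@dataclass
class MyChatWorkload(ChatWorkloadBase):
    def parse_row(self, row: dict[str, Any]) -> dict[str, Any]:
        # Map a raw "problem" column into the chat-messages format.
        return {"messages": [{"role": "user", "content": row["problem"]}]}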


@dataclass
class EvalWorkload(ChatWorkloadBase):
dataset_fraction: float = 1.0
sampling_params: Dict[str, Any] = field(
default_factory=lambda: {"max_tokens": 4096}
)

def parse_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
"""Parse each row in the dataset to make them compatible with
OpenAI chat API messages. Specifically, the output row should only
