Commit 50e79a0: changes
erictang000 committed Feb 4, 2025 (1 parent: 9ef28f9)
Showing 6 changed files with 22 additions and 22 deletions.
4 changes: 2 additions & 2 deletions skythought/skythought_evals/batch/__init__.py
@@ -3,11 +3,11 @@
 from .engines import init_engine_from_config
 from .pipeline import Pipeline
 from .workload import (
-    ChatWorkloadBase,
+    EvalWorkload,
 )

 __all__ = [
     "Pipeline",
     "init_engine_from_config",
-    "ChatWorkloadBase",
+    "EvalWorkload",
 ]
6 changes: 3 additions & 3 deletions skythought/skythought_evals/batch/engines/initializer.py
@@ -14,7 +14,7 @@
     download_model_from_hf,
     update_dict_recursive,
 )
-from ..workload import WorkloadBase
+from ..workload import EvalWorkload
 from .base import EngineBase

@@ -74,7 +74,7 @@ def get_engine_cls(self) -> EngineBase:
         """
         raise NotImplementedError

-    def get_engine_constructor_args(self, workload: WorkloadBase) -> Dict[str, Any]:
+    def get_engine_constructor_args(self, workload: EvalWorkload) -> Dict[str, Any]:
         """Get the engine constructor arguments.

         Args:

@@ -139,7 +139,7 @@ def max_model_len(self) -> Optional[int]:
         """The maximum model length set by the engine."""
         return self.engine_kwargs.get("max_model_len", None)

-    def get_engine_constructor_args(self, workload: WorkloadBase):
+    def get_engine_constructor_args(self, workload: EvalWorkload):
         from vllm import PoolingParams, SamplingParams
         from vllm.config import PoolerConfig
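
Note: get_engine_constructor_args is the hook each engine initializer implements to turn a workload into engine constructor kwargs. A minimal sketch of a hypothetical subclass follows; the class name MyEngineInitializer and the exact workload fields read are assumptions for illustration (only the method signature, self.engine_kwargs, and the vLLM imports appear in this diff).

from typing import Any, Dict, Optional

from skythought.skythought_evals.batch import EvalWorkload


class MyEngineInitializer:  # hypothetical; the real base class is not shown in this diff
    def __init__(self, engine_kwargs: Dict[str, Any]):
        self.engine_kwargs = engine_kwargs  # attribute seen in the diff above

    def get_engine_constructor_args(self, workload: EvalWorkload) -> Dict[str, Any]:
        from vllm import SamplingParams

        args: Dict[str, Any] = dict(self.engine_kwargs)
        # workload.sampling_params is referenced in pipeline.py below;
        # forwarding it as vLLM SamplingParams is an assumption.
        if workload.sampling_params is not None:
            args["sampling_params"] = SamplingParams(**workload.sampling_params)
        return args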
17 changes: 5 additions & 12 deletions skythought/skythought_evals/batch/pipeline.py
@@ -13,7 +13,7 @@
 from .env_config import EnvConfig
 from .logging import get_logger
 from .tokenizer import Detokenizer
-from .workload import WorkloadBase
+from .workload import EvalWorkload

 if TYPE_CHECKING:
     from ray.util.placement_group import PlacementGroup

@@ -48,7 +48,7 @@ def __init__(

     @classmethod
     def from_config(
-        cls, engine_cfg: Union[Dict[str, Any], str], workload: WorkloadBase, **kwargs
+        cls, engine_cfg: Union[Dict[str, Any], str], workload: EvalWorkload, **kwargs
     ):
         """Initialize the pipeline from a configuration file or dictionary.

@@ -72,16 +72,13 @@ def env_vars(self) -> Dict[str, Any]:

     def load(
         self,
-        ckpt_path: Optional[str] = None,
         repartition_by_batch_size: bool = False,
     ) -> Dataset:
         """Use the given workload to load and process the dataset,
         and then tokenize the prompts if needed. The processed dataset
         will be repartitioned based on the number of replicas and batch size.

         Args:
-            ckpt_path: The path to the checkpoint directory. If None, checkpointing
-                will be disabled.
             repartition_by_batch_size: Whether to repartition the dataset by the
                 batch size for fault tolerance granularity. You should enable
                 this when the dataset is not from parquet and checkpointing is

@@ -92,7 +89,6 @@ def load(
         """
         ds, num_blocks = self.workload.get_preprocessed_dataset(
             self.env_config.batch_size,
-            ckpt_path,
             repartition_by_batch_size,
         )
         if num_blocks is not None and num_blocks < self.num_replicas:

@@ -118,8 +114,6 @@ def load(
             batch_size=self.env_config.batch_size,
         )

-        ds = self.workload.postproc_after_tokenize(ds)
-
         # If max tokens in prompt is not set in the workload and max_model_len is not set
         # in the engine, we need to materialize the dataset to get the maximum tokens in prompt.
         # This may hurt the overall throughput but may be memory efficient.

@@ -146,8 +140,8 @@ def load(
         self.ds = ds
         return ds

-    def __call__(self, workload: WorkloadBase):
-        self.workload: WorkloadBase = workload
+    def __call__(self, workload: EvalWorkload):
+        self.workload: EvalWorkload = workload
         # Set the task to "embed" if sampling params are not given.
         self.task_type_str: str = (
             "auto" if self.workload.sampling_params is not None else "embed"

@@ -206,8 +200,7 @@ def run(
         if dataset is not None:
             self.ds = dataset
         elif self.ds is None:
-            ckpt_path = f"{output_path}_ckpt" if output_path is not None else None
-            self.load(ckpt_path, repartition_by_batch_size)
+            self.load(repartition_by_batch_size)
         assert self.ds is not None

         num_gpus = self.engine_initializer.num_gpus
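
Note: after this change the pipeline API is EvalWorkload-only, and load() no longer takes a checkpoint path. A hypothetical usage sketch, based only on the signatures visible in this diff (the engine config path and the EvalWorkload constructor arguments are made up for illustration):

from skythought.skythought_evals.batch import EvalWorkload, Pipeline

workload = EvalWorkload(dataset_path="data.parquet")  # hypothetical kwargs; not shown in this diff
pipeline = Pipeline.from_config("engine_config.yaml", workload=workload)  # engine_cfg may be a dict or a path
pipeline(workload)  # __call__ binds the workload and infers the task type from sampling_params
ds = pipeline.load(repartition_by_batch_size=False)  # the ckpt_path argument was removed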
11 changes: 11 additions & 0 deletions skythought/skythought_evals/batch/workload.py
@@ -117,6 +117,17 @@ def get_preprocessed_dataset(
         )
         return ds.map(mapper_fn), num_blocks

+    def tokenizer_constructor_kwargs(self, model: str):
+        """Return the keyword arguments for tokenizer constructor.
+
+        Args:
+            model: The model name.
+
+        Returns:
+            The keyword arguments for tokenizer constructor.
+        """
+        return {"model": model}
+
     def parse_row_with_carryover_input(self, row: dict[str, Any]) -> dict[str, Any]:
         """Same as parse_row but carries over the input keys that are not in the output row.
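
Note: the new tokenizer_constructor_kwargs method gives workloads a hook for customizing tokenizer construction; the base implementation only forwards the model name. A hypothetical override (the subclass name and the extra kwarg are illustrative, not from this repository):

from skythought.skythought_evals.batch import EvalWorkload


class LeftPaddedEvalWorkload(EvalWorkload):
    """Hypothetical subclass for illustration only."""

    def tokenizer_constructor_kwargs(self, model: str):
        kwargs = super().tokenizer_constructor_kwargs(model)  # {"model": model}, per the diff
        kwargs["padding_side"] = "left"  # assumed extra option
        return kwargs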
5 changes: 1 addition & 4 deletions skythought/skythought_evals/inference_and_check.py
@@ -539,7 +539,7 @@ def main():
     set_seed(args.seed)

     # use os to enable hf_transfer for model download
-    if not os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", None) in ["1", "True"]:
+    if os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", None) not in ["1", "True"]:
         os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

     if args.task not in TASK_NAMES_TO_YAML:
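
Note: this is a readability fix, not a behavior change. In Python, "not" binds more loosely than "in", so "not x in y" already parses as "not (x in y)"; PEP 8 (pycodestyle E713) simply prefers the "not in" spelling. A quick demonstration:

import os

value = os.environ.get("HF_HUB_ENABLE_HF_TRANSFER", None)
# Both spellings are equivalent; the second is the idiomatic one.
assert (not value in ["1", "True"]) == (value not in ["1", "True"])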
@@ -603,9 +603,6 @@ def main():
            return
    else:
        if args.use_ray:
-            # disable pyarrow warnings
-            data_ctx = ray.data.DataContext.get_current()
-            data_ctx.enable_fallback_to_arrow_object_ext_type = True
            llm = None
        else:
            llm = (
1 change: 0 additions & 1 deletion skythought/skythought_evals/requirements.txt
@@ -5,4 +5,3 @@ scipy
 datasets
 latex2sympy2
 pydantic
-hf_transfer
