
[evals] Add support for scaling evals and inference with ray #63

Merged · 33 commits · Feb 6, 2025

Conversation

@erictang000 (Collaborator) commented on Feb 3, 2025:

What does this PR do?

This PR adds support for using ray to speed up evals and data generation. Currently we are using a preliminary version built on ray data + vllm while we wait for the code at ray.llm to be fully open sourced (coming in the next 1-2 weeks), after which we will migrate over.

Speedups

[Benchmark figures: eval and data-generation speedups] Example speedups using `ray`: we found that for a single 8xA100 or 8xH100 node, setting `tp=4` with 2 replicas or `tp=2` with 4 replicas can be much faster than a single `tp=8` replica, especially for longer evals like MMLU.

We also see faster data generation when sampling n parallel generations (shown above for the DeepSeek-distilled Qwen-7B). For comparison, the same inference setup (32k max tokens on AIME with n=128) using the Qwen math repo and a single `tp=4` replica takes ~10 hours.

How to Use

To use the new path, simply add `--use_ray` to existing commands and set the relevant scaling parameters in `--ray_config`.

Reasonable defaults and examples of how to set advanced vllm engine arguments are provided in ray_configs/ray_config.yaml. For example, to run the Math-500 eval with Sky-T1-32B-Preview, you can use the following command:

python inference_and_check.py --model NovaSky-AI/Sky-T1-32B-Preview --task math500  --split test --max_tokens 8192 --use_ray --ray_config ray_configs/ray_config.yaml

where the ray_config looks like:

llm_engine: vllm # currently only vllm supported
accelerator_type: A100-80G # accelerator name as specified here: https://docs.ray.io/en/master/ray-core/accelerator-types.html#accelerator-types
engine_kwargs: # vllm engine kwargs 
  tensor_parallel_size: 2
  gpu_memory_utilization: 0.9
runtime_env:
  env_vars:
    VLLM_ATTENTION_BACKEND: "FLASH_ATTN"
env_config:
  num_replicas: 4 # number of vllm replicas 
  batch_size: 128 # ray data internal batch size (used for map_batches call internally). Should usually be set to a value in [64, 128, 256] for best performance.
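For intuition, here is a minimal sketch (not the repo's actual code) of the ray data + vllm pattern this config drives: each replica is a Ray actor holding one vLLM engine, and ray data fans batches out across the actor pool via map_batches. Model name, class name, and placeholder prompts are illustrative; tensor-parallel engines inside Ray actors may additionally need the multiprocessing executor or explicit placement groups depending on your vllm/ray versions.

import ray
from vllm import LLM, SamplingParams

class VLLMPredictor:
    def __init__(self):
        # One vLLM engine per Ray actor, mirroring engine_kwargs above.
        self.llm = LLM(
            model="NovaSky-AI/Sky-T1-32B-Preview",
            tensor_parallel_size=2,
            gpu_memory_utilization=0.9,
        )

    def __call__(self, batch):
        outputs = self.llm.generate(
            list(batch["prompt"]), SamplingParams(max_tokens=8192)
        )
        batch["response"] = [out.outputs[0].text for out in outputs]
        return batch

prompts = [{"prompt": "What is 1 + 1?"}]  # placeholder inputs
ds = ray.data.from_items(prompts)
ds = ds.map_batches(
    VLLMPredictor,
    batch_size=128,  # env_config.batch_size
    concurrency=4,   # env_config.num_replicas -> actor pool size
    num_gpus=2,      # GPUs per replica = tensor_parallel_size
)
results = ds.take_all()

The key design point is that replicas and tensor parallelism trade off against each other: num_replicas × tensor_parallel_size should equal the GPUs on the node (e.g. 4 × 2 = 8 on an 8xA100 node), and as the speedup numbers above suggest, more smaller replicas often win for throughput-bound evals.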

@lynnliu030 self-requested a review on February 3, 2025 04:20
@kouroshHakha (Collaborator) left a comment:


We don't have any unit tests in this repo, so some stuff to verify manually:

  • does the regular single node cli still work after these changes?
  • does it still work with OpenAI models?

There is a bit of extra stuff in workload.py that we can trim down to avoid confusing people.

@erictang000 (Collaborator, Author) replied:

  • does the regular single node cli still work after these changes?
  • does it still work with OpenAI models?

Yep, checked with the new updates and made sure everything works end-to-end for both inference_and_check and inference_and_save, with n = 1 and n > 1, for all 3 paths for getting completions (oai, vllm, ray + vllm).

Comment on lines 95 to 96

responses = copy.deepcopy(responses)
Collaborator:

Could you add a NOTE + TODO comment here for now explaining the issue we saw?

@erictang000 (Collaborator, Author) commented on Feb 4, 2025:

Bug details (a sketch of the requested NOTE + TODO comment follows this list):

  • a new Response object (a plain Python dataclass with str, int, and int attributes) is initialized from the values of ds.iter_rows() on a ray dataset
  • these responses are processed in a ProcessPoolExecutor, but when we exit the executor's context and it tries to clean up the response objects, we run into a SIGSEGV at the ray object store level (see the traceback below)
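For example, the requested comment might look something like this (hypothetical wording, summarizing the findings above):

# NOTE: these Response objects are built from rows of ds.iter_rows(), whose
# underlying buffers live in the Ray object store. Handing them to a
# ProcessPoolExecutor and letting the executor tear them down segfaulted the
# raylet in plasma::ReadReleaseRequest (see the traceback in PR #63).
# Deep-copying detaches the data from the object store before the executor
# touches it.
# TODO: remove this workaround once the underlying object-store release bug
# is fixed upstream.
responses = copy.deepcopy(responses)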

Traceback for posterity

(raylet) *** SIGSEGV received at time=1738696160 on cpu 214 ***                                                                                                                                                                     
(raylet) PC: @     0x56080cd7f1ae  (unknown)  plasma::ReadReleaseRequest()
(raylet)     @     0x7fcb9fbaf520       4656  (unknown)
(raylet)     @     0x56080cd5715f       1456  plasma::PlasmaStore::ProcessMessage()
(raylet)     @     0x56080cd50f15         32  std::_Function_handler<>::_M_invoke()
(raylet)     @     0x56080cd86d33       1280  plasma::Client::Create()::{lambda()#1}::operator()()
(raylet)     @     0x56080cf56aad       1376  ray::ClientConnection::ProcessMessage()
(raylet)     @     0x56080cf6de98       1168  EventTracker::RecordExecution()
(raylet)     @     0x56080cf58fb8        400  boost::asio::detail::reactive_socket_recv_op<>::do_complete()
(raylet)     @     0x56080d557f9b        128  boost::asio::detail::scheduler::do_run_one()
(raylet)     @     0x56080d55a529        288  boost::asio::detail::scheduler::run()
(raylet)     @     0x56080d55aa42         96  boost::asio::io_context::run()
(raylet)     @     0x56080cd50b20       1424  plasma::PlasmaStoreRunner::Start()
(raylet)     @     0x56080ccc4b05        208  std::thread::_State_impl<>::_M_run()
(raylet)     @     0x56080d6bafb0  258531312  execute_native_thread_routine
(raylet)     @ ... and at least 3 more frames
(raylet) {"asctime":"2025-02-04 19:09:20,137","levelname":"E","message":"*** SIGSEGV received at time=1738696160 on cpu 214 ***","component":"raylet","filename":"logging.cc","lineno":447}
(raylet) {"asctime":"2025-02-04 19:09:20,137","levelname":"E","message":"    @ ... and at least 3 more frames","component":"raylet","filename":"logging.cc","lineno":447}
(raylet) {"asctime":"2025-02-04 19:09:20,137","levelname":"E","message":"PC: @     0x56080cd7f1ae  (unknown)  plasma::ReadReleaseRequest()","component":"raylet","filename":"logging.cc","lineno":447}
(raylet)     @     0x56080cd873f8         48  std::_Function_handler<>::_M_invoke()
(raylet) {"asctime":"2025-02-04 19:09:20,137","levelname":"E","message":"    @     0x56080d6bafb0  258531312  execute_native_thread_routine","component":"raylet","filename":"logging.cc","lineno":447} [repeated 16x across cluster]

@SumanthRH (Collaborator) left a comment:


LGTM. Thanks!!

@SumanthRH merged commit 806f09c into NovaSky-AI:main on Feb 6, 2025
2 checks passed