cleanup entrypoint
Signed-off-by: SumanthRH <[email protected]>
SumanthRH committed Feb 6, 2025
1 parent 24f0f14 commit c80c3c1
Showing 3 changed files with 61 additions and 37 deletions.
84 changes: 52 additions & 32 deletions skythought/skythought_evals/inference_and_check.py
@@ -6,14 +6,17 @@
 import os
 from concurrent.futures import ProcessPoolExecutor, as_completed
 from functools import partial
-from typing import Dict, Tuple, Any
+from typing import Dict, Tuple

 import numpy as np
 import ray
+from openai import OpenAI
 from skythought_evals.batch import Pipeline, init_engine_from_config
 from skythought_evals.batch.env_config import EnvConfig
-from skythought_evals.batch.workload import EvalWorkload, load_config_from_path as load_rayllm_config_from_path
-from openai import OpenAI
+from skythought_evals.batch.workload import EvalWorkload
+from skythought_evals.batch.workload import (
+    load_config_from_path as load_rayllm_config_from_path,
+)
 from skythought_evals.models import ModelConfig, get_system_prompt_keys
 from skythought_evals.tasks import (
     TASK_HANDLER_MAP,
@@ -30,6 +33,7 @@
 module_dir = os.path.dirname(os.path.abspath(__file__))
 DEFAULT_RAY_CONFIG_RELATIVE_PATH = "ray_configs/ray_config.yaml"

+
 class NumpyEncoder(json.JSONEncoder):
     def default(self, obj):
         if isinstance(obj, np.ndarray):
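
Note: the hunk above is truncated inside NumpyEncoder. For context, here is a minimal self-contained sketch of how such an encoder is typically completed and used; the tolist() body and the super() fallback are assumptions, since the diff does not show the rest of the class:

import json

import numpy as np


class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        # Assumed body: serialize numpy arrays as plain lists.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


print(json.dumps({"scores": np.array([0.1, 0.9])}, cls=NumpyEncoder))
# prints: {"scores": [0.1, 0.9]}
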
@@ -87,25 +91,26 @@ def fetch_responses_ray(conversations, max_tokens, temp, args):
     responses = ds.materialize()
     return responses

-def _parse_response_for_idx(response: Response, sample_idx: int, args) -> Tuple[SingleParsedResponse, Dict[str, int]]:
+
+def _parse_response_for_idx(
+    response: Response, sample_idx: int, args
+) -> Tuple[SingleParsedResponse, Dict[str, int]]:
     if args.model.startswith("openai"):
         content = response.response.strip()
     else:
         content = response.response[sample_idx].strip()
     response_entry = SingleParsedResponse(content=content)

     if args.model.startswith("openai"):
         token_usage_for_response = {
-            "completion_tokens": response.num_completion_tokens,
-            "prompt_tokens": response.num_input_tokens,
-        }
+            "completion_tokens": response.num_completion_tokens,
+            "prompt_tokens": response.num_input_tokens,
+        }

     else:
         token_usage_for_response = {
-            "completion_tokens": response.num_completion_tokens[
-                sample_idx
-            ],
-            "prompt_tokens": response.num_input_tokens,
+            "completion_tokens": response.num_completion_tokens[sample_idx],
+            "prompt_tokens": response.num_input_tokens,
         }
     return response_entry, token_usage_for_response
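
Since this helper is the centerpiece of the cleanup, a short usage sketch follows. The SimpleNamespace stand-ins are hypothetical objects that mimic the shapes the helper expects for a non-OpenAI model (a list of n generations plus per-sample completion-token counts):

from types import SimpleNamespace

args = SimpleNamespace(model="qwen-7b-instruct", n=2)  # hypothetical CLI args
response = SimpleNamespace(
    response=["answer A", "answer B"],  # one generation per sample
    num_completion_tokens=[12, 9],      # indexed by sample_idx
    num_input_tokens=87,                # shared prompt length
)

for sample_idx in range(args.n):
    entry, usage = _parse_response_for_idx(response, sample_idx, args)
    print(entry.content, usage)
# answer A {'completion_tokens': 12, 'prompt_tokens': 87}
# answer B {'completion_tokens': 9, 'prompt_tokens': 87}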

@@ -132,7 +137,9 @@ def inference(llm, conversations, max_tokens, temp, args):

         responses = [Response.from_openai_response(response) for response in responses]
     else:
-        sampling_params = SamplingParams(max_tokens=max_tokens, temperature=temp, n=args.n)
+        sampling_params = SamplingParams(
+            max_tokens=max_tokens, temperature=temp, n=args.n
+        )
         responses = llm.chat(
             messages=conversations, sampling_params=sampling_params, use_tqdm=True
         )
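
The reflowed call above is standard vLLM offline chat inference. A minimal sketch of how n controls the number of sampled completions per conversation (the model name is a placeholder), which is what makes the per-sample indexing in _parse_response_for_idx necessary:

from vllm import LLM, SamplingParams

llm = LLM(model="Qwen/Qwen2-7B-Instruct")  # placeholder model
sampling_params = SamplingParams(max_tokens=64, temperature=0.7, n=2)
responses = llm.chat(
    messages=[[{"role": "user", "content": "What is 2 + 2?"}]],
    sampling_params=sampling_params,
    use_tqdm=True,
)
# Each returned RequestOutput carries n completions: len(responses[0].outputs) == 2.
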
@@ -170,23 +177,27 @@ def perform_inference_and_check(
             continue

         responses = inference(llm, conversations, max_tokens, temp, args)

         total_correct = 0
         total_finish = 0
         with ProcessPoolExecutor(max_workers=32) as executor:
             future_to_task = {}
             token_usages = {}
             for idx, response in enumerate(responses):
                 for sample_idx in range(args.n):
-                    # response_entry at this point doesn't contain correctness check.
-                    response_entry, token_usage_for_response = _parse_response_for_idx(response, sample_idx, args)
-                    if idx not in token_usages:
+                    # response_entry at this point doesn't contain correctness check.
+                    response_entry, token_usage_for_response = _parse_response_for_idx(
+                        response, sample_idx, args
+                    )
+                    if idx not in token_usages:
                         token_usages[idx] = []
                     token_usages[idx].append(token_usage_for_response)
                     # submit correctness check for response
                     future_to_task[
                         executor.submit(
-                            handler.update_results, remaining_data[idx], response_entry.content
+                            handler.update_results,
+                            remaining_data[idx],
+                            response_entry.content,
                         )
                     ] = (idx, sample_idx)

@@ -196,8 +207,8 @@
                 desc="Processing Generations",
             ):
                 idx, sample_idx = future_to_task[future]
-                # TODO (sumanthrh): the returned entry is currently a dict and can be confusing.
-                # this should also be a ParsedResponse object.
+                # TODO (sumanthrh): the returned entry is currently a dict and can be confusing.
+                # this should also be a ParsedResponse object.
                 response_entry: dict = future.result()
                 total_correct += response_entry["correctness"]
                 total_finish += 1
@@ -212,15 +223,15 @@
                 prompt = conversations[idx][1]["content"]
                 results[problem_key]["prompt"] = prompt
                 results[problem_key]["input_conversation"] = conversations[idx]

                 if str(temp) not in results[problem_key]["responses"]:
                     results[problem_key]["responses"][str(temp)] = []

                 # args.n responses can come in any order, but we can safely ignore
                 # sample idx and just save as results come in
                 results[problem_key]["responses"][str(temp)].append(response_entry)
                 # do this only once per problem/idx
-                if str(temp) not in results[problem_key]["token_usages"]:
+                if str(temp) not in results[problem_key]["token_usages"]:
                     results[problem_key]["token_usages"][str(temp)] = token_usages[idx]

     print(f"Final acc: {total_correct}/{total_finish}")
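
To follow the bookkeeping in this hunk: responses are grouped per problem and per temperature, while token usage is recorded once per problem and temperature. A sketch of the resulting shape, with keys taken from the code above and values purely illustrative:

results = {
    "problem_key": {
        "prompt": "...",
        "input_conversation": [{"role": "user", "content": "..."}],
        "responses": {
            # one entry per sample, args.n in total, in completion order
            "0.7": [{"content": "...", "correctness": True, "reason": None}],
        },
        "token_usages": {
            "0.7": [{"completion_tokens": 128, "prompt_tokens": 512}],
        },
    },
}
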
@@ -259,7 +270,9 @@ def perform_inference_and_check(
                 else 0
             ),
             "avg_prompt_tokens": (
-                round(total_prompt_tokens / num_responses_total, 3) if total_prompt_tokens else 0
+                round(total_prompt_tokens / num_responses_total, 3)
+                if total_prompt_tokens
+                else 0
             ),
         }

@@ -404,7 +417,9 @@ def perform_inference_and_save(
             token_usages = []
             completion_token = 0
             for sample_idx in range(args.n):
-                response_entry, token_usage_for_response = _parse_response_for_idx(response, sample_idx, args)
+                response_entry, token_usage_for_response = _parse_response_for_idx(
+                    response, sample_idx, args
+                )
                 token_usages.append(token_usage_for_response)
                 completion_token += token_usage_for_response["completion_tokens"]
                 response_entries.append(response_entry)
@@ -558,7 +573,12 @@ def main():
         help="Ray configuration file if using ray for scaling inference. By default, we use the example in ray_configs/ray_config.yaml",
     )
     parser.add_argument(
-        "--dtype", type=str, choices=["float32", "auto", "float16", "bfloat16"], help="dtype for inference with vLLM. Full-precision by default. 'auto' refers to automatically inferring dtype for the model", default="float32"
+        "--dtype",
+        type=str,
+        choices=["float32", "auto", "float16", "bfloat16"],
+        help="dtype for inference with vLLM. Full-precision by default. "
+        "'auto' refers to automatically inferring dtype for the model",
+        default="float32",
     )
     args = parser.parse_args()
     # load ray config
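
One detail worth keeping in mind with the reflowed --dtype argument: adjacent Python string literals concatenate with no separator, so the first fragment must keep its trailing space (the one-line version being replaced had one). A quick illustration:

help_text = (
    "dtype for inference with vLLM. Full-precision by default. "
    "'auto' refers to automatically inferring dtype for the model"
)
assert "default. 'auto'" in help_text  # without the trailing space: "default.'auto'"
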
@@ -602,7 +622,7 @@ def main():
     # create result dir if not exists
     if args.result_dir and not os.path.exists(args.result_dir):
         os.makedirs(args.result_dir)
-    temperature_str = ','.join(map(str, temperatures))
+    temperature_str = ",".join(map(str, temperatures))
     file_suffix = f"{model_config.name}_{args.task}_{args.split}_{args.subset}_{args.filter_difficulty}_{args.start}_{args.end}_t{temperature_str}"
     if (
         args.math_difficulty_lower_bound is not None
@@ -624,9 +644,7 @@
         args.math_difficulty_lower_bound is not None
         or args.math_difficulty_upper_bound is not None
     ):
-        converted_file = (
-            f"{args.result_dir}/converted_{file_suffix}.json"
-        )
+        converted_file = f"{args.result_dir}/converted_{file_suffix}.json"
     else:
         converted_file = f"{args.result_dir}/converted_{file_suffix}.json"
     if os.path.exists(converted_file):
@@ -640,7 +658,9 @@
     llm = (
         OpenAI()
         if args.model.startswith("openai")
-        else LLM(model=args.model, tensor_parallel_size=args.tp, dtype=args.dtype)
+        else LLM(
+            model=args.model, tensor_parallel_size=args.tp, dtype=args.dtype
+        )
     )
     if args.inference:
         perform_inference_and_save(
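
For reference, the non-OpenAI branch above constructs a plain vLLM engine; a sketch with hypothetical values, where dtype mirrors the new --dtype flag ("auto" lets vLLM infer the dtype from the checkpoint):

from vllm import LLM

llm = LLM(model="Qwen/Qwen2-7B-Instruct", tensor_parallel_size=1, dtype="bfloat16")
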
2 changes: 1 addition & 1 deletion skythought/skythought_evals/models/base.py
@@ -3,7 +3,7 @@
 from typing import Optional, Union

 import yaml
-from pydantic import BaseModel, Field, PrivateAttr, field_validator, model_validator
+from pydantic import BaseModel, Field, PrivateAttr, model_validator

 MODEL_CONFIG_FILE_PATH = Path(__file__).parent / "model_configs.yaml"
 # cache the configs in a global var
12 changes: 8 additions & 4 deletions skythought/skythought_evals/util/response.py
@@ -81,11 +81,15 @@ def from_vllm_response(cls, response) -> "Response":
         )


-@dataclass
+@dataclass
 class SingleParsedResponse:
-    content: str
+    content: str
     correctness: Optional[bool] = None
-    reason: Optional[str] = None
+    reason: Optional[str] = None

     def as_dict(self):
-        return {"content": self.content, "correctness": self.correctness, "reason": self.reason}
+        return {
+            "content": self.content,
+            "correctness": self.correctness,
+            "reason": self.reason,
+        }
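
A quick usage sketch of the reformatted dataclass, assuming the class definition above; values are illustrative:

entry = SingleParsedResponse(content="The answer is 42.")
entry.correctness = True
print(entry.as_dict())
# {'content': 'The answer is 42.', 'correctness': True, 'reason': None}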
