Commit

Add batched inference

lewtun committed May 3, 2024
1 parent 568e7b3 commit 9362eff

Showing 3 changed files with 42 additions and 47 deletions.
8 changes: 1 addition & 7 deletions examples/scripts/dpo_online.py
@@ -182,24 +182,18 @@ def process(row):

judge = HuggingFaceJudge()
prompts_ds = load_dataset(args.dataset_name, split="test").shuffle(seed=42).select(range(64))
# prompts_ds = prompts_ds.map(
# lambda x: {
# "prompt": tokenizer.apply_chat_template(x["chosen"][:-1], tokenize=False, add_generation_prompt=True)
# }
# )
win_rate_callback = WinRateCallback(
prompts=prompts_ds["prompt"],
judge=judge,
generation_config=GenerationConfig(
temperature=0.9,
do_sample=True,
num_return_sequences=1,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
max_new_tokens=512,
),
trainer=trainer,
# batch_size=4,
batch_size=4,
)
trainer.add_callback(win_rate_callback)

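The commented-out `map` over the evaluation split is no longer needed: as the `callbacks.py` change below shows, `WinRateCallback` now applies the chat template to each raw prompt itself. A minimal sketch of that templating step (the checkpoint name is illustrative; any tokenizer that ships a chat template works):

```python
from transformers import AutoTokenizer

# Illustrative checkpoint, not taken from the commit.
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")

raw_prompt = "What is the capital of France?"

# Wrap the raw prompt as a single user turn and append the generation prompt,
# mirroring what WinRateCallback now does for every entry in `prompts`.
templated = tokenizer.apply_chat_template(
    [{"role": "user", "content": raw_prompt}],
    tokenize=False,
    add_generation_prompt=True,
)
print(templated)
```

The `batch_size=4` argument then controls how many of these templated prompts are passed to each `generate` call inside the callback.
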
70 changes: 36 additions & 34 deletions trl/trainer/callbacks.py
@@ -26,42 +26,47 @@ def __init__(
generation_config: GenerationConfig,
judge,
trainer,
batch_size: int = 4,
):
self.prompts = prompts
self.prompts = [
trainer.tokenizer.apply_chat_template(
[{"role": "user", "content": p}], tokenize=False, add_generation_prompt=True
)
for p in prompts
]
self.generation_config = generation_config
self.completions = []
self.judge = judge
self.ref_completions = []
self.trainer = trainer
self.batch_size = batch_size

def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
model = self.trainer.model_wrapped
tokenizer = kwargs["tokenizer"]
tokenizer.padding_side = "left"
accelerator = self.trainer.accelerator

with accelerator.split_between_processes(self.prompts, apply_padding=True) as prompts:
# local_dataset = Dataset.from_dict(prompts)

with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
unwrapped_model.eval()
for prompt in tqdm(prompts, desc="Generating ref completions for win rate"):
# tokenized_prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
tokenized_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
return_tensors="pt",
return_dict=True,
).to(model.device)
generation = unwrapped_model.generate(
**tokenized_prompt,
for idx in tqdm(
range(0, len(prompts), self.batch_size), desc="Generating reference model completions for win rate"
):
batch = prompts[idx : idx + self.batch_size]
tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(
model.device
)
generations = unwrapped_model.generate(
**tokenized_batch,
generation_config=self.generation_config,
)
padded_prompt_length = tokenized_prompt.input_ids.shape[1]
generation = generation[:, padded_prompt_length:]
text_generations = tokenizer.batch_decode(generation, skip_special_tokens=True)
for prompt, generation in zip(tokenized_batch.input_ids, generations):
# Remove prompt from generation
generation = generation[len(prompt) :]
completion = tokenizer.decode(generation, skip_special_tokens=True)
self.ref_completions.append(completion)

ref_response = text_generations[0]
self.ref_completions.append(ref_response)
unwrapped_model.train()

def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
@@ -74,27 +79,24 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra

with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
unwrapped_model.eval()
for idx, prompt in enumerate(tqdm(prompts, desc="Generating completions for win rate")):
# tokenized_prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
tokenized_prompt = tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True,
return_tensors="pt",
return_dict=True,
).to(model.device)
for idx in tqdm(range(0, len(prompts), self.batch_size), desc="Generating completions for win rate"):
batch = prompts[idx : idx + self.batch_size]
tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(
model.device
)
generations = unwrapped_model.generate(
**tokenized_prompt,
**tokenized_batch,
generation_config=self.generation_config,
)
padded_prompt_length = tokenized_prompt.input_ids.shape[1]
generations = generations[:, padded_prompt_length:]
text_generations = tokenizer.batch_decode(generations, skip_special_tokens=True)

response0 = text_generations[0]
response1 = self.ref_completions[idx]
for batch_idx, (prompt, generation) in enumerate(zip(tokenized_batch.input_ids, generations)):
# Remove prompt from generation
generation = generation[len(prompt) :]
response_0 = tokenizer.decode(generation, skip_special_tokens=True)
response_1 = self.ref_completions[idx + batch_idx]
annotation_batch["completions"].append([response_0, response_1])

annotation_batch["completions"].append([response0, response1])
unwrapped_model.train()

# TODO: rerun with the order of responses swapped and average
results_dict = self.judge.judge_batch(annotation_batch["prompts"], annotation_batch["completions"])
results_dict = Dataset.from_dict(
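The substance of the change above is replacing one `generate` call per prompt with one call per batch of left-padded prompts, then stripping the prompt tokens from each returned row before decoding. A self-contained sketch of that pattern outside the callback (model name, prompts, and generation settings are placeholders, not part of the commit):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "sshleifer/tiny-gpt2"  # illustrative tiny checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Decoder-only models must be left-padded so generation starts right after each prompt.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

prompts = ["Hello, my name is", "The weather today is", "Once upon a time"]
batch_size = 2
generation_config = GenerationConfig(
    max_new_tokens=16, do_sample=False, pad_token_id=tokenizer.pad_token_id
)

completions = []
for idx in range(0, len(prompts), batch_size):
    batch = prompts[idx : idx + batch_size]
    inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        outputs = model.generate(**inputs, generation_config=generation_config)
    # With left padding every row's prompt (plus padding) occupies the same number
    # of leading positions, so slicing at input_ids.shape[1] leaves only the completion.
    prompt_len = inputs.input_ids.shape[1]
    completions.extend(
        tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)
    )

print(completions)
```

Because every row in a padded batch has the same length, slicing at `input_ids.shape[1]` is equivalent to the per-row `generation[len(prompt):]` used in the callback; left padding is what guarantees the completion starts immediately after that offset.
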
11 changes: 5 additions & 6 deletions trl/trainer/judges.py
@@ -1,3 +1,4 @@
import logging
import os
import random
from abc import ABC, abstractmethod
@@ -6,7 +7,6 @@

from accelerate import Accelerator
from huggingface_hub import InferenceClient
import logging
from requests import HTTPError

from ..import_utils import is_llmblender_available, is_openai_available
@@ -96,7 +96,9 @@ def get_response(self, content: str) -> str:

def judge(self, prompt: str, completion_pair: List[str], shuffle_order: bool) -> int:
if self.max_tries <= 0:
logging.info(f"Max retries reached for prompt {prompt}. Returning random choice.")
logging.info(
f"Max retries reached for prompt:\n\n{prompt}\nand completion pair:\n\n{completion_pair}\n\nReturning random choice."
)
return random.choice([0, 1])

shuffle_index = 0 if not shuffle_order else random.choice([0, 1])
@@ -106,19 +108,16 @@ def judge(self, prompt: str, completion_pair: List[str], shuffle_order: bool) ->
reply = self.get_response(content)
reply = reply.strip()

# First answer
if reply in [
"0",
]:
return shuffle_index
# Second answer
elif reply in [
"1",
]:
return 1 - shuffle_index
# Unknown reply
else:
logging.info(f"Judge gave response {reply} instead of the expected 0 or 1. Retrying.")
logging.info(f"Judge gave response `{reply}` instead of the expected 0 or 1. Retrying.")
self.max_tries -= 1
return self.judge(prompt, completion_pair, shuffle_order)

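For context on the retry path that the improved log messages cover: the two completions are shown to the judge in a possibly shuffled order to counter position bias, the reply is mapped back through that shuffle, an unexpected reply triggers a retry, and exhausted retries fall back to a random choice. A minimal sketch of that control flow with a stubbed backend (the function and its signature are illustrative, not the TRL API):

```python
import logging
import random
from typing import Callable, List

def judge_pair(
    prompt: str,
    completion_pair: List[str],
    get_reply: Callable[[str, List[str]], str],  # stub for the model/API call
    shuffle_order: bool = True,
    max_tries: int = 3,
) -> int:
    """Return 0 if the first completion wins, 1 if the second wins."""
    if max_tries <= 0:
        logging.info("Max retries reached; returning a random choice.")
        return random.choice([0, 1])

    # Randomly swap the pair so the judge cannot systematically favour one position.
    shuffle_index = random.choice([0, 1]) if shuffle_order else 0
    shuffled = completion_pair[::-1] if shuffle_index == 1 else completion_pair

    reply = get_reply(prompt, shuffled).strip()
    if reply == "0":    # judge picked the first completion as shown
        return shuffle_index
    elif reply == "1":  # judge picked the second completion as shown
        return 1 - shuffle_index
    else:
        logging.info(f"Unexpected judge reply `{reply}`; retrying.")
        return judge_pair(prompt, completion_pair, get_reply, shuffle_order, max_tries - 1)
```

In practice `get_reply` would wrap the `InferenceClient` call that `get_response` performs; the sketch only isolates the order-shuffling and fallback logic.
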
