From 9362eff6350a1396ac3bd738d037afd5bb01ac6c Mon Sep 17 00:00:00 2001
From: Lewis Tunstall
Date: Fri, 3 May 2024 09:04:25 +0000
Subject: [PATCH] Add batched inference

---
 examples/scripts/dpo_online.py |  8 +---
 trl/trainer/callbacks.py       | 70 +++++++++++++++++-----------
 trl/trainer/judges.py          | 11 +++---
 3 files changed, 42 insertions(+), 47 deletions(-)

diff --git a/examples/scripts/dpo_online.py b/examples/scripts/dpo_online.py
index 6384f88a15..eb8465d044 100644
--- a/examples/scripts/dpo_online.py
+++ b/examples/scripts/dpo_online.py
@@ -182,24 +182,18 @@ def process(row):
     judge = HuggingFaceJudge()
     prompts_ds = load_dataset(args.dataset_name, split="test").shuffle(seed=42).select(range(64))
-    # prompts_ds = prompts_ds.map(
-    #     lambda x: {
-    #         "prompt": tokenizer.apply_chat_template(x["chosen"][:-1], tokenize=False, add_generation_prompt=True)
-    #     }
-    # )
     win_rate_callback = WinRateCallback(
         prompts=prompts_ds["prompt"],
         judge=judge,
         generation_config=GenerationConfig(
             temperature=0.9,
             do_sample=True,
-            num_return_sequences=1,
             pad_token_id=tokenizer.pad_token_id,
             eos_token_id=tokenizer.eos_token_id,
             max_new_tokens=512,
         ),
         trainer=trainer,
-        # batch_size=4,
+        batch_size=4,
     )
     trainer.add_callback(win_rate_callback)

diff --git a/trl/trainer/callbacks.py b/trl/trainer/callbacks.py
index 8f59349e85..ca56f0cae2 100644
--- a/trl/trainer/callbacks.py
+++ b/trl/trainer/callbacks.py
@@ -26,42 +26,47 @@ def __init__(
         generation_config: GenerationConfig,
         judge,
         trainer,
+        batch_size: int = 4,
     ):
-        self.prompts = prompts
+        self.prompts = [
+            trainer.tokenizer.apply_chat_template(
+                [{"role": "user", "content": p}], tokenize=False, add_generation_prompt=True
+            )
+            for p in prompts
+        ]
         self.generation_config = generation_config
         self.completions = []
         self.judge = judge
         self.ref_completions = []
         self.trainer = trainer
+        self.batch_size = batch_size

     def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
         model = self.trainer.model_wrapped
         tokenizer = kwargs["tokenizer"]
+        tokenizer.padding_side = "left"
         accelerator = self.trainer.accelerator

         with accelerator.split_between_processes(self.prompts, apply_padding=True) as prompts:
-            # local_dataset = Dataset.from_dict(prompts)
-
             with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
                 unwrapped_model.eval()
-                for prompt in tqdm(prompts, desc="Generating ref completions for win rate"):
-                    # tokenized_prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
-                    tokenized_prompt = tokenizer.apply_chat_template(
-                        [{"role": "user", "content": prompt}],
-                        add_generation_prompt=True,
-                        return_tensors="pt",
-                        return_dict=True,
-                    ).to(model.device)
-                    generation = unwrapped_model.generate(
-                        **tokenized_prompt,
+                for idx in tqdm(
+                    range(0, len(prompts), self.batch_size), desc="Generating reference model completions for win rate"
+                ):
+                    batch = prompts[idx : idx + self.batch_size]
+                    tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(
+                        model.device
+                    )
+                    generations = unwrapped_model.generate(
+                        **tokenized_batch,
                         generation_config=self.generation_config,
                     )
-                    padded_prompt_length = tokenized_prompt.input_ids.shape[1]
-                    generation = generation[:, padded_prompt_length:]
-                    text_generations = tokenizer.batch_decode(generation, skip_special_tokens=True)
+                    for prompt, generation in zip(tokenized_batch.input_ids, generations):
+                        # Remove prompt from generation
+                        generation = generation[len(prompt) :]
+                        completion = tokenizer.decode(generation, skip_special_tokens=True)
+                        self.ref_completions.append(completion)

-                    ref_response = text_generations[0]
-                    self.ref_completions.append(ref_response)
                 unwrapped_model.train()

     def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
@@ -74,27 +79,24 @@ def on_evaluate(self, args: TrainingArguments, state: TrainerState, control: Tra
             with unwrap_model_for_generation(model, accelerator) as unwrapped_model:
                 unwrapped_model.eval()
-                for idx, prompt in enumerate(tqdm(prompts, desc="Generating completions for win rate")):
-                    # tokenized_prompt = tokenizer(prompt, return_tensors="pt").to(model.device)
-                    tokenized_prompt = tokenizer.apply_chat_template(
-                        [{"role": "user", "content": prompt}],
-                        add_generation_prompt=True,
-                        return_tensors="pt",
-                        return_dict=True,
-                    ).to(model.device)
+                for idx in tqdm(range(0, len(prompts), self.batch_size), desc="Generating completions for win rate"):
+                    batch = prompts[idx : idx + self.batch_size]
+                    tokenized_batch = tokenizer(batch, return_tensors="pt", padding=True, truncation=True).to(
+                        model.device
+                    )
                     generations = unwrapped_model.generate(
-                        **tokenized_prompt,
+                        **tokenized_batch,
                         generation_config=self.generation_config,
                     )
-                    padded_prompt_length = tokenized_prompt.input_ids.shape[1]
-                    generations = generations[:, padded_prompt_length:]
-                    text_generations = tokenizer.batch_decode(generations, skip_special_tokens=True)
-
-                    response0 = text_generations[0]
-                    response1 = self.ref_completions[idx]
+                    for batch_idx, (prompt, generation) in enumerate(zip(tokenized_batch.input_ids, generations)):
+                        # Remove prompt from generation
+                        generation = generation[len(prompt) :]
+                        response_0 = tokenizer.decode(generation, skip_special_tokens=True)
+                        response_1 = self.ref_completions[idx + batch_idx]
+                        annotation_batch["completions"].append([response_0, response_1])

-                    annotation_batch["completions"].append([response0, response1])
                 unwrapped_model.train()

+        # TODO, rerun with order or responses swapped and average
         results_dict = self.judge.judge_batch(annotation_batch["prompts"], annotation_batch["completions"])
         results_dict = Dataset.from_dict(
diff --git a/trl/trainer/judges.py b/trl/trainer/judges.py
index 25c7954e69..67852863e1 100644
--- a/trl/trainer/judges.py
+++ b/trl/trainer/judges.py
@@ -1,3 +1,4 @@
+import logging
 import os
 import random
 from abc import ABC, abstractmethod
@@ -6,7 +7,6 @@
 from accelerate import Accelerator
 from huggingface_hub import InferenceClient
-import logging
 from requests import HTTPError

 from ..import_utils import is_llmblender_available, is_openai_available

@@ -96,7 +96,9 @@ def get_response(self, content: str) -> str:

     def judge(self, prompt: str, completion_pair: List[str], shuffle_order: bool) -> int:
         if self.max_tries <= 0:
-            logging.info(f"Max retries reached for prompt {prompt}. Returning random choice.")
+            logging.info(
+                f"Max retries reached for prompt:\n\n{prompt}\nand completion pair:\n\n{completion_pair}\n\nReturning random choice."
+            )
             return random.choice([0, 1])

         shuffle_index = 0 if not shuffle_order else random.choice([0, 1])
@@ -106,19 +108,16 @@ def judge(self, prompt: str, completion_pair: List[str], shuffle_order: bool) ->
         reply = self.get_response(content)
         reply = reply.strip()

-        # First answer
         if reply in [
             "0",
         ]:
             return shuffle_index
-        # Second answer
         elif reply in [
             "1",
         ]:
             return 1 - shuffle_index
-        # Unknown reply
         else:
-            logging.info(f"Judge gave response {reply} instead of the expected 0 or 1. Retrying.")
+            logging.info(f"Judge gave response `{reply}` instead of the expected 0 or 1. Retrying.")
             self.max_tries -= 1
             return self.judge(prompt, completion_pair, shuffle_order)