diff --git a/generator/confs/torchtune-llama3-8B_full.yaml b/generator/confs/torchtune-llama3-8B_full.yaml new file mode 100644 index 0000000..3cd35c0 --- /dev/null +++ b/generator/confs/torchtune-llama3-8B_full.yaml @@ -0,0 +1,88 @@ +# Config for multi-device full finetuning in full_finetune_distributed.py +# using a Llama3 8B Instruct model +# +# This config assumes that you've run the following command before launching +# this run: +# tune download meta-llama/Meta-Llama-3-8B-Instruct --output-dir /tmp/Meta-Llama-3-8B-Instruct --hf-token +# +# To launch on 4 devices, run the following command from root: +# tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full +# +# You can add specific overrides through the command line. For example +# to override the checkpointer directory while launching training +# you can run: +# tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full checkpointer.checkpoint_dir= +# +# This config works best when the model is being fine-tuned on 2+ GPUs. +# Single device full finetuning requires more memory optimizations. It's +# best to use 8B_full_single_device.yaml for those cases + + +# Tokenizer +tokenizer: + _component_: torchtune.models.llama3.llama3_tokenizer + path: ./models/Meta-Llama-3-8B/original/tokenizer.model + +# Dataset +dataset: + _component_: torchtune.datasets.instruct_dataset + source: csv + data_files: state_tactic_pairs.csv + split: train + template: generator.template.StateTacticPairTemplate + train_on_input: False + max_seq_len: 4096 +seed: null +shuffle: True + +# Model Arguments +model: + _component_: torchtune.models.llama3.llama3_8b + +checkpointer: + _component_: torchtune.utils.FullModelHFCheckpointer + checkpoint_dir: ./models/Meta-Llama-3-8B/ + checkpoint_files: [ + model-00001-of-00004.safetensors, + model-00002-of-00004.safetensors, + model-00003-of-00004.safetensors, + model-00004-of-00004.safetensors, + ] + recipe_checkpoint: null + output_dir: ./models/Meta-Llama-3-8B-finetuned/ + model_type: LLAMA3 +resume_from_checkpoint: False + +# Fine-tuning arguments +batch_size: 4 +epochs: 1 + +optimizer: + _component_: torch.optim.AdamW + lr: 2e-5 + foreach: False + +loss: + _component_: torch.nn.CrossEntropyLoss +max_steps_per_epoch: null +gradient_accumulation_steps: 1 + + +# Training env +device: cuda + +# Memory management +enable_activation_checkpointing: True +memory_efficient_fsdp_wrap: True + +# Reduced precision +dtype: bf16 + +# Logging +metric_logger: + _component_: torchtune.utils.metric_logging.WandBLogger + project: ReProver + log_dir: ${output_dir} +output_dir: ./logs/leandojo-llama3-finetune +log_every_n_steps: 1 +log_peak_memory_stats: false diff --git a/generator/full_finetune_distributed.py b/generator/full_finetune_distributed.py new file mode 100644 index 0000000..0e637cd --- /dev/null +++ b/generator/full_finetune_distributed.py @@ -0,0 +1,609 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
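# ---------------------------------------------------------------------------
# Annotation (illustrative sketch, not part of the patch): how the `dataset`
# block of the config above is expected to turn one row of
# state_tactic_pairs.csv into a training example. The template string and the
# column names come from generator/template.py and generator/preprocess_data.py
# further down in this diff; the exact torchtune behaviour (appending the
# "output" column after the prompt and masking the prompt when
# train_on_input=False) is assumed here.
#
#   row = {"state": "n : ℕ\n⊢ n + 0 = n", "output": "simp"}
#   prompt = StateTacticPairTemplate.format(row)
#   # prompt == "### State:\nn : ℕ\n⊢ n + 0 = n\n\n### Tactic:"
#   # The model is trained to continue the prompt with row["output"] ("simp"),
#   # with the prompt tokens excluded from the loss and sequences truncated to
#   # max_seq_len=4096.
# ---------------------------------------------------------------------------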
+ +import sys +import time + +from functools import partial +from typing import Any, Dict, Optional, Tuple +from warnings import warn + +import torch +from omegaconf import DictConfig, ListConfig + +from torch import nn +from torch.distributed import init_process_group +from torch.distributed.fsdp import ( + CPUOffload, + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel as FSDP, + StateDictType, +) +from torch.optim import Optimizer +from torch.utils.data import DataLoader, DistributedSampler + +from torchtune import config, modules, utils +from torchtune.datasets import ConcatDataset +from torchtune.recipe_interfaces import FTRecipeInterface +from torchtune.utils.activations import apply_selective_activation_checkpointing + +from tqdm import tqdm + + +log = utils.get_logger("DEBUG") + + +class FullFinetuneRecipeDistributed(FTRecipeInterface): + """ + Full finetuning recipe for dense transformer-based LLMs such as Llama2. This recipe supports + distributed training and can be run on a single node (1 to 8 GPUs). + + Features: + - FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU + is not supported. + + - Activation Checkpointing. This can be controlled using the ``activation_checkpointing`` + flag. Activation checkpointing helps reduce the memory footprint since we no longer keep + activations in memory and instead recompute them during the backward pass. This is especially + helpful for larger batch sizes when you're memory constrained. But these savings in memory + come at the cost of training performance. In most cases training can slow-down quite a bit as + a result of this activation recomputation. + + - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype`` + flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In + most cases this should halve the memory footprint of full precision (fp32) training, without + loss in model quality (will depend on the model, training data and other settings). For + GPUs which do not support bfloat16, we fall back to fp32. Mixed precision training and fp16 + precision are currently not supported. + + - Gradient Accumulation. You can simulate larger batch sizes by accumulating gradients. This is + controlled using the ``gradient_accumulation_steps`` flag. + + Total Batch Size = batch_size * number of GPUs * gradient accumulation steps. + + For example: with batch_size=1, nproc_per_node=2 and gradient_accumulation_steps=32 we get a + total batch size of 64. + + Gradient accumulation is especially useful when you are memory constrained. In this case, + accumulating gradients might give you better training speed than enabling activation + checkpointing. + + - Checkpointing. Model weights are checkpointed both at the end of each epoch and at the end of + training. Optimizer state and recipe state (seed, total_epochs, number of epochs run etc) are + only saved at the end of a given epoch and used in case of resuming training. + + Resuming training is controlled by the ``resume_from_checkpoint`` flag. Mid-epoch checkpointing is + currently not supported. + + For more details on the checkpointer, please take a look at + our checkpointer deepdive (https://pytorch.org/torchtune/main/deep_dives/checkpointer.html). + + - Logging. Terminal, Disk, WandB and TensorBoard are all supported. + + For a full list of example configs for this recipe, run ``tune ls`` on the command line. 
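    # Annotation (not part of the upstream docstring): with the accompanying
    # torchtune-llama3-8B_full.yaml and the 4-GPU launch command from its
    # header, the formula above works out to
    #   batch_size * nproc_per_node * gradient_accumulation_steps
    #     = 4 * 4 * 1 = 16 state-tactic pairs per optimizer step.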
Each config + has example commands for how to kick-off training. + + Args: + cfg (DictConfig): OmegaConf object parsed from yaml file + + Raises: + ValueError: If ``dtype`` is set to fp16. + """ + + def __init__(self, cfg: DictConfig) -> None: + + self._device = utils.get_device(device=cfg.device) + self._dtype = utils.get_dtype(cfg.dtype, device=self._device) + + if self._dtype == torch.float16: + raise ValueError( + "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead." + ) + + if ( + cfg.get("fsdp_cpu_offload", False) + and cfg.get("fused", False) + and not utils.torch_version_ge("2.4.0") + ): + raise RuntimeError( + "Using fused optimizer on CPU is only supported in PyTorch nightly." + ) + + # logging attributes + self._output_dir = cfg.output_dir + self._log_every_n_steps = cfg.get("log_every_n_steps", 1) + self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False) + + # _is_rank_zero is used primarily for logging. In the future, the logger + # should directly take care of this + _, rank = utils.get_world_size_and_rank() + self._is_rank_zero = rank == 0 + + # Training cfg + self._resume_from_checkpoint = cfg.resume_from_checkpoint + self._gradient_accumulation_steps = cfg.gradient_accumulation_steps + + # These are public properties which are updated by the checkpoint loader + # when ``resume_from_checkpoint`` is `True` or validated in tests + self.seed = utils.set_seed(seed=cfg.seed) + self.epochs_run = 0 + self.total_epochs = cfg.epochs + self.max_steps_per_epoch = cfg.max_steps_per_epoch + self.global_step = 0 + + def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]: + """ + Extract the checkpoint state from file and validate. If resume_from_checkpoint + is True, this also includes the recipe state. + """ + self._checkpointer = config.instantiate( + cfg_checkpointer, + resume_from_checkpoint=self._resume_from_checkpoint, + ) + checkpoint_dict = self._checkpointer.load_checkpoint() + + if self._resume_from_checkpoint: + self._update_recipe_state(checkpoint_dict) + return checkpoint_dict + + def _update_recipe_state(self, ckpt_dict: Dict[str, Any]) -> None: + """ + Updates the recipe state from checkpoint. + """ + try: + self.epochs_run = ckpt_dict[utils.EPOCHS_KEY] + + # on mismatch, warn the user and prevent the override + if self.seed != ckpt_dict[utils.SEED_KEY]: + warn( + message=( + "Config value for seed does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.SEED_KEY]}" + ) + ) + self.seed = ckpt_dict[utils.SEED_KEY] + if self.max_steps_per_epoch != ckpt_dict[utils.MAX_STEPS_KEY]: + warn( + message=( + "Config value for max_steps_per_epoch does not match the checkpoint value, " + f"using the checkpoint value: {ckpt_dict[utils.MAX_STEPS_KEY]}" + ) + ) + self.max_steps_per_epoch = ckpt_dict[utils.MAX_STEPS_KEY] + + # on mismatch, warn the user but allow the override + if self.total_epochs != ckpt_dict[utils.TOTAL_EPOCHS_KEY]: + warn( + message=( + "Config value for total_epochs does not match the checkpoint value, " + f"using the config value: {self.total_epochs}" + ) + ) + + except KeyError as e: + raise KeyError( + "Checkpoint does not contain the required keys needed for updating recipe state. " + "Are you sure you passed in the right recipe checkpoint?" + ) from e + + def setup(self, cfg: DictConfig) -> None: + """ + Sets up the recipe state correctly. This includes setting recipe attributes based + on the ``resume_from_checkpoint`` flag. 
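        # Annotation (not part of the patch): rough shape of the checkpoint
        # dict returned by load_checkpoint() when resume_from_checkpoint=True
        # (the key constants are torchtune's; the concrete values here are
        # illustrative only):
        #   {
        #       utils.MODEL_KEY: {...},        # full model state dict
        #       utils.OPT_KEY: {...},          # FSDP-consolidated optimizer state
        #       utils.SEED_KEY: 1234,
        #       utils.EPOCHS_KEY: 1,           # epochs already completed
        #       utils.TOTAL_EPOCHS_KEY: 1,
        #       utils.MAX_STEPS_KEY: None,
        #   }
        # save_checkpoint() further down writes the same recipe-state keys when
        # it saves an intermediate (mid-training) checkpoint.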
+ """ + if self._is_rank_zero: + self._metric_logger = config.instantiate(cfg.metric_logger) + + # log config with parameter override + self._metric_logger.log_config(cfg) + + ckpt_dict = self.load_checkpoint(cfg.checkpointer) + + # ``_setup_model`` handles initialization and loading the state dict. This method + # should be called before ``_setup_optimizer`` since transforming the optimizer + # state dict requires the model + self._model = self._setup_model( + cfg_model=cfg.model, + enable_activation_checkpointing=cfg.enable_activation_checkpointing, + memory_efficient_fsdp_wrap=cfg.get("memory_efficient_fsdp_wrap", False), + fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), + model_state_dict=ckpt_dict[utils.MODEL_KEY], + ac_mode=cfg.get("ac_mode", None), + ac_option=cfg.get("ac_option", None), + ) + + self._tokenizer = config.instantiate(cfg.tokenizer) + + # _setup_optimizer should take in ckpt_dict only if training is resumed from + # checkpoint. Transforming the opt state dict is handled by this method + self._optimizer = self._setup_optimizer( + cfg_optimizer=cfg.optimizer, + opt_state_dict=( + ckpt_dict[utils.OPT_KEY] if self._resume_from_checkpoint else None + ), + ) + + self._loss_fn = config.instantiate(cfg.loss) + + # sampler and dataloader depend on the tokenizer and loss_fn and should be + # setup after both of these are initialized + self._sampler, self._dataloader = self._setup_data( + cfg_dataset=cfg.dataset, + shuffle=cfg.shuffle, + batch_size=cfg.batch_size, + ) + + # Finally update the recipe state which can only be correctly set after all of the + # other components have been initialized and updated. + # + # Number of training steps in each epoch depends on the number of batches produced + # by the dataloader, the max_steps_per_epoch param set by the user and the + # gradient_accumulation_steps param. This value is used for logging and tracking + # training state. The computation should happen after the dataloader has been setup + self._steps_per_epoch = ( + len(self._dataloader) // self._gradient_accumulation_steps + ) + if ( + self.max_steps_per_epoch is not None + and self.max_steps_per_epoch < self._steps_per_epoch + ): + self._steps_per_epoch = self.max_steps_per_epoch + self.global_step = self.epochs_run * self._steps_per_epoch + + def _setup_model( + self, + cfg_model: DictConfig, + enable_activation_checkpointing: bool, + memory_efficient_fsdp_wrap: bool, + fsdp_cpu_offload: bool, + model_state_dict: Dict[str, Any], + ac_mode: Optional[str] = None, + ac_option: Optional[int] = None, + ) -> nn.Module: + """ + Model initialization has some important considerations: + a. To minimize GPU peak memory, we load the model on CPU with the right + dtype. To ensure that we don't instantiate ``world_size`` number of models, + we initialize on meta_device for all ranks other than rank 0. + b. Rank 0 is also responsible for calling ``load_state_dict`` and loading the + model weights from checkpoint. + c. While wrapping the model with FSDP, we set ``sync_module_states`` + to TRUE and broadcast module params and buffers from rank 0. + d. The ``device_id`` param ensures that the FSDP initialization happens on + the correct device. + """ + if self._is_rank_zero: + log.info("FSDP is enabled. 
Instantiating Model on CPU for Rank 0 ...") + init_start = time.perf_counter() + + with utils.set_default_dtype(self._dtype): + model = config.instantiate(cfg_model) + + log.info( + f"Model instantiation took {time.perf_counter() - init_start:.2f} secs" + ) + + # Load both the model weights. This should happen only on Rank 0 + model.load_state_dict(model_state_dict) + + else: + # For non-zero ranks, load the model on meta device + with utils.set_default_dtype(self._dtype), torch.device("meta"): + model = config.instantiate(cfg_model) + + if self._dtype == torch.bfloat16: + model = model.to(torch.bfloat16) + + # We currently have two versions of activation checkpointing in this recipe + # for testing and BC purposes. ``enable_activation_checkpointing`` controls + # the older version of AC and this behavior is unchanged + # ac_mode and ac_option together control selective AC. This is only enabled + # when these are set AND ``enable_activation_checkpointing`` is set to False + # We'll clean this up as soon as testing of AC is complete + ac_mode = ac_mode + ac_option = ac_option + + if (not enable_activation_checkpointing) and (ac_mode is not None): + apply_selective_activation_checkpointing( + model, + ac_mode, + ac_option, + ) + + # Wrap the model with FSDP. This will ensure that the model is sharded + # across all available GPUs. + model = FSDP( + module=model, + auto_wrap_policy=utils.get_full_finetune_fsdp_wrap_policy( + memory_efficient_fsdp_wrap=memory_efficient_fsdp_wrap, + modules_to_wrap={modules.TransformerDecoderLayer}, + ), + cpu_offload=CPUOffload(offload_params=fsdp_cpu_offload), + sharding_strategy=torch.distributed.fsdp.ShardingStrategy.FULL_SHARD, + device_id=self._device, + # this recipe does not currently support mixed precision training + mixed_precision=None, + # Ensure we broadcast params and buffers from rank 0 + sync_module_states=True, + # Initialize empty modules on all non-zero ranks + param_init_fn=( + lambda module: ( + module.to_empty(device=torch.device("cuda"), recurse=False) + if not self._is_rank_zero + else None + ) + ), + ) + + # Ensure no params and buffers are on meta device + utils.validate_no_params_on_meta_device(model) + + # original activation checkpointing (full) - flip the condition above + if enable_activation_checkpointing and ac_mode is None: + utils.set_activation_checkpointing( + model, auto_wrap_policy={modules.TransformerDecoderLayer} + ) + + if self._is_rank_zero: + memory_stats = utils.get_memory_stats(device=self._device) + utils.log_memory_stats(memory_stats) + + # synchronize before training begins + torch.distributed.barrier() + + return model + + def _setup_optimizer( + self, cfg_optimizer: DictConfig, opt_state_dict: Optional[Dict[str, Any]] = None + ) -> Optimizer: + """ + Set up the optimizer. This method also handles transforing the state dict + for FSDP. + """ + optimizer = config.instantiate(cfg_optimizer, self._model.parameters()) + + if opt_state_dict: + opt_state_dict = FSDP.optim_state_dict_to_load( + self._model, optimizer, opt_state_dict + ) + optimizer.load_state_dict(opt_state_dict) + + if self._is_rank_zero: + log.info("Optimizer is initialized.") + return optimizer + + def _setup_data( + self, + cfg_dataset: DictConfig, + shuffle: bool, + batch_size: int, + ) -> Tuple[DistributedSampler, DataLoader]: + """ + All data related setup happens here. Currently this recipe only supports the + DistributedSamplers with Map-style Datasets which fit into memory. 
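        # Annotation (not part of the patch): cfg.dataset may be either a
        # single dataset, as in the YAML config above, or a list of datasets,
        # which are concatenated (see the ListConfig branch below). A
        # hypothetical multi-dataset override would look like:
        #   dataset:
        #     - _component_: torchtune.datasets.instruct_dataset
        #       source: csv
        #       data_files: state_tactic_pairs.csv
        #       template: generator.template.StateTacticPairTemplate
        #     - _component_: torchtune.datasets.alpaca_dataset
        # (the second entry is purely illustrative)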
Other samplers, + iterable datasets and streaming datasets are not supported. + """ + world_size, rank = utils.get_world_size_and_rank() + + if isinstance(cfg_dataset, ListConfig): + datasets = [ + config.instantiate(single_cfg_dataset, tokenizer=self._tokenizer) + for single_cfg_dataset in cfg_dataset + ] + ds = ConcatDataset(datasets=datasets) + packed = False + else: + ds = config.instantiate(cfg_dataset, tokenizer=self._tokenizer) + packed = cfg_dataset.get("packed", False) + + sampler = DistributedSampler( + ds, + num_replicas=world_size, + rank=rank, + shuffle=shuffle, + seed=0, + ) + dataloader = DataLoader( + dataset=ds, + batch_size=batch_size, + sampler=sampler, + collate_fn=( + partial( + utils.padded_collate, + padding_idx=self._tokenizer.pad_id, + ignore_idx=self._loss_fn.ignore_index, + ) + if not packed + else None + ), + ) + + if self._is_rank_zero: + log.info("Dataset and Sampler are initialized.") + + return sampler, dataloader + + def save_checkpoint(self, epoch: int) -> None: + """ + Save state dict to file. The recipe save_checkpoint method is responsible for + correctly creating the checkpoint dict and passing to the checkpointer. + """ + checkpoint_dict = {} + + # To prevent GPU memory from spiking during checkpoint save, + # we consolidate the full model and optim state dicts on CPU for rank 0 + with FSDP.state_dict_type( + self._model, + StateDictType.FULL_STATE_DICT, + FullStateDictConfig(offload_to_cpu=True, rank0_only=True), + FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True), + ): + cpu_state_dict = self._model.state_dict() + opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer) + + # Now that we have the model and opt state dict, create the actual checkpoint dict + # to be sent to the checkpointer and ultimately written to file + if self._is_rank_zero: + + checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict}) + + # if training is in-progress, checkpoint the optimizer state as well + if epoch + 1 < self.total_epochs: + checkpoint_dict.update( + { + utils.OPT_KEY: opt_state_dict, + utils.SEED_KEY: self.seed, + utils.EPOCHS_KEY: self.epochs_run, + utils.TOTAL_EPOCHS_KEY: self.total_epochs, + utils.MAX_STEPS_KEY: self.max_steps_per_epoch, + } + ) + + self._checkpointer.save_checkpoint( + checkpoint_dict, + epoch=epoch, + intermediate_checkpoint=(epoch + 1 < self.total_epochs), + ) + + def train(self) -> None: + """ + The core training loop. Supports training on subsets of the dataset using the + ``max_steps_per_epoch``. 
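        # Annotation (not part of the patch): shape of a batch as produced by
        # the padded_collate collate_fn configured in _setup_data (layout
        # illustrative; padded_collate internals are torchtune's):
        #   batch["tokens"]: [b, s] int tensor, right-padded with tokenizer.pad_id
        #   batch["labels"]: [b, s] int tensor, padded with loss_fn.ignore_index
        #                    (-100 for torch.nn.CrossEntropyLoss), so padding and
        #                    any masked prompt tokens contribute nothing to the loss
        # The loop below then drops the last logit and the first label
        # (next-token prediction) before calling the loss.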
+ """ + # clean up before training begins + utils.cleanup_before_training() + + _, rank = utils.get_world_size_and_rank() + + # zero out the gradients before starting training + self._optimizer.zero_grad() + + # Initialize tokens count and running loss (for grad accumulation) + t0 = time.perf_counter() + running_loss = 0 + num_tokens = 0 + + # self.epochs_run should be non-zero when we're resuming from a checkpoint + for curr_epoch in range(self.epochs_run, self.total_epochs): + + # Update the sampler to ensure data is correctly shuffled across epochs + # in case shuffle is True + self._sampler.set_epoch(curr_epoch) + + pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0)) + for idx, batch in enumerate(self._dataloader): + if ( + self.max_steps_per_epoch is not None + and (idx // self._gradient_accumulation_steps) + == self.max_steps_per_epoch + ): + break + + # Both are shape [b, s] + tokens, labels = batch["tokens"], batch["labels"] + # Get the attention mask and position ids from the dataset if they + # exist. Currently, only sample packing in PackedDataset returns these + mask = batch.get("mask", None) # shape [b, s, s] + input_pos = batch.get("input_pos", None) # shape [b, s] + + tokens = tokens.to(self._device) + num_tokens += tokens.numel() + labels = labels.to(self._device) + mask = mask.to(self._device) if mask is not None else None + input_pos = ( + input_pos.to(self._device) if input_pos is not None else None + ) + + logits = self._model(tokens, mask=mask, input_pos=input_pos) + # Shift so that tokens < n predict n + logits = logits[..., :-1, :].contiguous() + labels = labels[..., 1:].contiguous() + logits = logits.transpose(1, 2) + # Compute loss + loss = self._loss_fn(logits, labels) + + loss = loss / self._gradient_accumulation_steps + running_loss += loss + loss.backward() + + # Step with optimizer + if (idx + 1) % self._gradient_accumulation_steps == 0: + self._optimizer.step() + self._optimizer.zero_grad(set_to_none=True) + + # Update the number of steps when the weights are updated + self.global_step += 1 + + loss_to_log = running_loss.item() + pbar.update(1) + pbar.set_description( + f"{curr_epoch+1}|{self.global_step}|Loss: {loss_to_log}" + ) + + # Log per-step metrics + if ( + self.global_step % self._log_every_n_steps == 0 + and self._is_rank_zero + ): + time_per_step = time.perf_counter() - t0 + log_dict = { + "loss": loss_to_log, + "lr": self._optimizer.param_groups[0]["lr"], + "tokens_per_second_per_gpu": num_tokens / time_per_step, + } + if self._log_peak_memory_stats: + log_dict.update(utils.get_memory_stats(device=self._device)) + self._metric_logger.log_dict( + log_dict, + step=self.global_step, + ) + + # Reset running stats for the next step + running_loss = 0 + num_tokens = 0 + t0 = time.perf_counter() + + self.epochs_run += 1 + self.save_checkpoint(epoch=curr_epoch) + + def cleanup(self) -> None: + if self._is_rank_zero: + self._metric_logger.close() + torch.distributed.destroy_process_group() + + +@config.parse +def recipe_main(cfg: DictConfig) -> None: + """ + Entry point for the recipe. + + Configurable parameters are read in the following order: + - Parameters specified in config (see available configs through ``tune ls``) + - Overwritten by arguments from the command-line + """ + if not utils.is_distributed(): + raise RuntimeError( + "Distributed finetune recipe should be run via a distributed launcher." 
+ "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]" + ) + + init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl") + if cfg.get("fsdp_cpu_offload", False): + # Utilize all available CPU cores for intra-op parallelism. This provides ~2x + # speed up when benchmarking fused AdamW on CPU + utils.set_torch_num_threads() + + config.log_config(recipe_name="FullFinetuneRecipeDistributed", cfg=cfg) + + recipe = FullFinetuneRecipeDistributed(cfg=cfg) + recipe.setup(cfg=cfg) + recipe.train() + recipe.cleanup() + + +if __name__ == "__main__": + sys.exit(recipe_main()) diff --git a/generator/model.py b/generator/model.py index 75d2b40..b171d5e 100644 --- a/generator/model.py +++ b/generator/model.py @@ -1,16 +1,18 @@ """Lightning module for the tactic generator.""" import os +import ray +import math import torch import shutil import openai import pickle -from vllm import LLM from lean_dojo import Pos from loguru import logger import pytorch_lightning as pl from torchmetrics import Metric from abc import ABC, abstractmethod +from vllm import SamplingParams from typing import List, Dict, Any, Optional, Tuple from transformers import T5ForConditionalGeneration, AutoTokenizer @@ -247,7 +249,7 @@ def on_validation_epoch_end(self) -> None: from prover.evaluate import evaluate # Avoid circular import. - ckpt_path = f"{self.trainer.log_dir}/checkpoints/last-tmp.ckpt" + ckpt_path = f"{self.trainer.log_dir}/last-tmp.ckpt" self.trainer.save_checkpoint(ckpt_path) logger.info(f"Saved checkpoint to {ckpt_path}. Evaluating...") torch.cuda.empty_cache() @@ -528,11 +530,10 @@ def batch_generate( ] -class SyncVllmGenerator(TacticGenerator): - def __init__(self, model_path: str, num_gpus: int) -> None: - self.llm = LLM(model_path, tensor_parallel_size=num_gpus) +class VllmGenerator(TacticGenerator): + def __init__(self, vllm_actor) -> None: + self.vllm_actor = vllm_actor - @abstractmethod def generate( self, state: str, @@ -541,12 +542,16 @@ def generate( theorem_pos: Pos, num_samples: int, ) -> List[Tuple[str, float]]: - import pdb - - pdb.set_trace() - raise NotImplementedError + outputs = ray.get( + self.vllm_actor.generate.remote( + f"### State:\n{state}\n\n### Tactic:", num_samples + ) + ) + return [ + (remove_marks(x.text), math.exp(x.cumulative_logprob)) + for x in outputs[0].outputs + ] - @abstractmethod def batch_generate( self, state: List[str], @@ -555,7 +560,12 @@ def batch_generate( theorem_pos: List[Pos], num_samples: int, ) -> List[List[Tuple[str, float]]]: - import pdb - - pdb.set_trace() - raise NotImplementedError + inputs = [f"### State:\n{s}\n\n### Tactic:" for s in state] + outputs = ray.get(self.vllm_actor.generate.remote(inputs, num_samples)) + return [ + [ + (remove_marks(x.text), math.exp(x.cumulative_logprob)) + for x in oup.outputs + ] + for oup in outputs + ] diff --git a/generator/preprocess_data.py b/generator/preprocess_data.py new file mode 100644 index 0000000..51cb435 --- /dev/null +++ b/generator/preprocess_data.py @@ -0,0 +1,31 @@ +import pdb +import csv +import json +import random + +from common import format_state, format_tactic + + +def main() -> None: + pairs = [] + data_path = "../data/leandojo_benchmark_4/random/train.json" + + for thm in json.load(open(data_path)): + for tac in thm["traced_tactics"]: + if "annotated_tactic" in tac: + tactic = format_tactic(*tac["annotated_tactic"], normalize=True) + else: + tactic = format_tactic(tac["tactic"], [], normalize=True) + pairs.append({"state": format_state(tac["state_before"]), "output": 
tactic}) + + random.shuffle(pairs) + + with open("state_tactic_pairs.csv", "wt") as oup: + wt = csv.DictWriter(oup, fieldnames=["state", "output"]) + wt.writeheader() + for st in pairs: + wt.writerow(st) + + +if __name__ == "__main__": + main() diff --git a/generator/template.py b/generator/template.py new file mode 100644 index 0000000..aa6d227 --- /dev/null +++ b/generator/template.py @@ -0,0 +1,15 @@ +from torchtune.data import InstructTemplate +from typing import Mapping, Any, Optional, Dict + + +class StateTacticPairTemplate(InstructTemplate): + template = "### State:\n{state}\n\n### Tactic:" + # template = "[GOAL]\n{state}\n[PROOFSTEP]\n" + + @classmethod + def format( + cls, sample: Mapping[str, Any], column_map: Optional[Dict[str, str]] = None + ) -> str: + column_map = column_map or {} + key_state = column_map.get("state", "state") + return cls.template.format(state=sample[key_state]) diff --git a/prover/evaluate.py b/prover/evaluate.py index 4a1c24e..70e34d5 100644 --- a/prover/evaluate.py +++ b/prover/evaluate.py @@ -97,6 +97,7 @@ def evaluate( full_name: Optional[str] = None, name_filter: Optional[str] = None, num_theorems: Optional[int] = None, + use_vllm: bool = False, ckpt_path: Optional[str] = None, indexed_corpus_path: Optional[str] = None, tactic: Optional[str] = None, @@ -116,6 +117,7 @@ def evaluate( # Search for proofs using multiple concurrent provers. prover = DistributedProver( + use_vllm, ckpt_path, indexed_corpus_path, tactic, @@ -180,6 +182,7 @@ def main() -> None: parser.add_argument("--full-name", type=str) parser.add_argument("--name-filter", type=str) parser.add_argument("--num-theorems", type=int) + parser.add_argument("--use-vllm", action="store_true") parser.add_argument( "--ckpt_path", type=str, @@ -230,6 +233,7 @@ def main() -> None: args.full_name, args.name_filter, args.num_theorems, + args.use_vllm, args.ckpt_path, args.indexed_corpus_path, args.tactic, diff --git a/prover/proof_search.py b/prover/proof_search.py index c971bb2..8752e18 100644 --- a/prover/proof_search.py +++ b/prover/proof_search.py @@ -24,10 +24,16 @@ from dataclasses import dataclass from typing import List, Optional, Tuple from ray.util.actor_pool import ActorPool +from vllm import LLM, SamplingParams, RequestOutput from common import zip_strict from prover.search_tree import * -from generator.model import RetrievalAugmentedGenerator, FixedTacticGenerator +from generator.model import ( + TacticGenerator, + RetrievalAugmentedGenerator, + FixedTacticGenerator, + VllmGenerator, +) @dataclass(frozen=True) @@ -300,29 +306,16 @@ def check_invariants(self): @ray.remote -class CpuProver(BestFirstSearchProver): - """Ray actor for running an instance of `BestFirstSearchProver` on a CPU.""" +class ProverActor(BestFirstSearchProver): + """Ray actor for running an instance of `BestFirstSearchProver`.""" def __init__( self, - ckpt_path: Optional[str], - indexed_corpus_path: Optional[str], - tactic: Optional[str], - module: Optional[str], + tac_gen: TacticGenerator, timeout: int, num_sampled_tactics: int, debug: bool, ) -> None: - if ckpt_path is None: - tac_gen = FixedTacticGenerator(tactic, module) - else: - tac_gen = RetrievalAugmentedGenerator.load( - ckpt_path, device=torch.device("cpu"), freeze=True - ) - if tac_gen.retriever is not None: - if indexed_corpus_path is not None: - tac_gen.retriever.load_corpus(indexed_corpus_path) - tac_gen.retriever.reindex_corpus(batch_size=32) super().__init__( tac_gen, timeout, @@ -331,47 +324,41 @@ def __init__( ) -@ray.remote(num_gpus=1) -class 
GpuProver(BestFirstSearchProver): - """Ray actor for running an instance of `BestFirstSearchProver` on a GPU.""" - - def __init__( - self, - ckpt_path: Optional[str], - indexed_corpus_path: Optional[str], - tactic: Optional[str], - module: Optional[str], - timeout: int, - num_sampled_tactics: int, - debug: bool, - ) -> None: - if ckpt_path is None: - tac_gen = FixedTacticGenerator(tactic, module) - else: - tac_gen = RetrievalAugmentedGenerator.load( - ckpt_path, device=torch.device("cuda"), freeze=True - ) - if tac_gen.retriever is not None: - if indexed_corpus_path is not None: - tac_gen.retriever.load_corpus(indexed_corpus_path) - tac_gen.retriever.reindex_corpus(batch_size=32) - super().__init__( - tac_gen, - timeout, - num_sampled_tactics, - debug, +@ray.remote +class VllmActor: + """Ray actor for running an instance of `vllm.LLM`, which is shared by all `ProverActor` instances.""" + + def __init__(self, model_path: str) -> None: + self.num_gpus = len(ray.get_gpu_ids()) + self.model_path = model_path + + def initialize(self) -> None: + logger.info("Initializing vLLM") + # TODO: Try `--enable-prefix-caching` and other parameters in https://docs.vllm.ai/en/stable/models/engine_args.html#engine-args. + self.llm = LLM(self.model_path, tensor_parallel_size=self.num_gpus) + + def generate( + self, inputs: Union[str, List[str]], num_samples: int + ) -> List[RequestOutput]: + sampling_params = SamplingParams( + n=num_samples, temperature=0, use_beam_search=True, early_stopping=False ) + outputs = self.llm.generate(inputs, sampling_params, use_tqdm=False) + if isinstance(inputs, str): + assert len(outputs) == 1 + return outputs class DistributedProver: """A distributed prover that uses Ray to parallelize the proof search. - It is a wrapper around `CpuProver` and `GpuProver` that handles the different + It is a wrapper around `ProverActor` that handles the different devices and different number of concurrent provers. 
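    # Annotation (not part of the patch): all ProverActor workers share one
    # VllmActor, and tactic generation goes through Ray. Sketch of the round
    # trip implemented by VllmGenerator in generator/model.py (values
    # illustrative):
    #   outputs = ray.get(vllm_actor.generate.remote(
    #       "### State:\nn : ℕ\n⊢ n + 0 = n\n\n### Tactic:", num_samples))
    #   # outputs[0].outputs is a list of vLLM CompletionOutput objects; each is
    #   # turned into a (tactic, probability) pair via
    #   #   (remove_marks(o.text), math.exp(o.cumulative_logprob))
    # Beam search with temperature=0 (see VllmActor.generate above) makes the
    # returned candidates deterministic for a given state.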
""" def __init__( self, + use_vllm: bool, ckpt_path: Optional[str], indexed_corpus_path: Optional[str], tactic: Optional[str], @@ -386,20 +373,26 @@ def __init__( assert tactic and not indexed_corpus_path else: assert not tactic and not module - self.distributed = num_workers > 1 + if ckpt_path is None: + tac_gen = FixedTacticGenerator(tactic, module) + elif use_vllm: + assert indexed_corpus_path is None + vllm_actor = VllmActor.options(num_gpus=num_gpus).remote(ckpt_path) + ray.get(vllm_actor.initialize.remote()) + tac_gen = VllmGenerator(vllm_actor) + else: + device = torch.device("cuda") if num_gpus > 0 else torch.device("cpu") + tac_gen = RetrievalAugmentedGenerator.load( + ckpt_path, device=device, freeze=True + ) + if tac_gen.retriever is not None: + assert indexed_corpus_path is not None + tac_gen.retriever.load_corpus(indexed_corpus_path) + + self.distributed = num_workers > 1 if not self.distributed: assert num_gpus <= 1 - if ckpt_path is None: - tac_gen = FixedTacticGenerator(tactic, module) - else: - device = torch.device("cuda") if num_gpus > 0 else torch.device("cpu") - tac_gen = RetrievalAugmentedGenerator.load( - ckpt_path, device=device, freeze=True - ) - if tac_gen.retriever is not None: - assert indexed_corpus_path is not None - tac_gen.retriever.load_corpus(indexed_corpus_path) self.prover = BestFirstSearchProver( tac_gen, timeout, num_sampled_tactics, debug ) @@ -407,13 +400,14 @@ def __init__( if num_gpus >= 1: logger.info(f"Launching {num_workers} workers with {num_gpus} GPUs.") - num_gpus_per_worker = num_gpus / num_workers + if use_vllm: + # GPUs are managed by `VllmActor`. + num_gpus_per_worker = 0 + else: + num_gpus_per_worker = num_gpus / num_workers provers = [ - GpuProver.options(num_gpus=num_gpus_per_worker).remote( - ckpt_path, - indexed_corpus_path, - tactic, - module, + ProverActor.options(num_gpus=num_gpus_per_worker).remote( + tac_gen, timeout=timeout, num_sampled_tactics=num_sampled_tactics, debug=debug, @@ -423,11 +417,8 @@ def __init__( else: logger.info(f"Launching {num_workers} CPU workers.") provers = [ - CpuProver.remote( - ckpt_path, - indexed_corpus_path, - tactic, - module, + ProverActor.remote( + tac_gen, timeout=timeout, num_sampled_tactics=num_sampled_tactics, debug=debug,