From 7c4842837fe644ab00555b27e64cce932c7e50c9 Mon Sep 17 00:00:00 2001
From: Remi Delacourt
Date: Sun, 21 Jul 2024 22:29:17 +0000
Subject: [PATCH] Fix for finetuning

The demo previously reused the inference adapter id for finetuning, and
serve.py unconditionally tried to download adapter weights from the
HuggingFace Hub. Give the demo's finetuning LoRA its own model id and
create its weights from scratch (init_lora_weights=True, rank 16,
alpha 16.0, targeting down_proj), and skip the Hub download/conversion
in serve.py for adapters whose weights are initialized locally.

---
 inference/python/peft_demo/demo_class.py |  6 ++-
 python/flexflow/serve/serve.py           | 57 ++++++++++++------------
 2 files changed, 34 insertions(+), 29 deletions(-)

diff --git a/inference/python/peft_demo/demo_class.py b/inference/python/peft_demo/demo_class.py
index 4b5448310a..dc9f9051f5 100644
--- a/inference/python/peft_demo/demo_class.py
+++ b/inference/python/peft_demo/demo_class.py
@@ -48,6 +48,10 @@ def initialize_flexflow(self):
             self.llm.cache_path,
             self.configs.finetuning_peft_model_id,
             trainable=True,
+            init_lora_weights=True,
+            rank=16,
+            lora_alpha=16.0,
+            target_modules=["down_proj"],
             base_model_name_or_path=self.configs.base_model,
             optimizer_type=ff.OptimizerType.OPTIMIZER_TYPE_SGD,
             optimizer_kwargs={
@@ -159,7 +163,7 @@ def main():
     model_configs = {
         "base_model": "meta-llama/Meta-Llama-3-8B",
         "inference_peft_model_id": "goliaro/llama-3-8b-lora",
-        "finetuning_peft_model_id": "goliaro/llama-3-8b-lora",
+        "finetuning_peft_model_id": "flechman/llama-3-8b-lora-dolly",
         "cache_path": os.environ.get("FF_CACHE_PATH", ""),
         "refresh_cache": False,
         "full_precision": True,
diff --git a/python/flexflow/serve/serve.py b/python/flexflow/serve/serve.py
index c71429082b..3d26a1840e 100644
--- a/python/flexflow/serve/serve.py
+++ b/python/flexflow/serve/serve.py
@@ -285,35 +285,36 @@ def convert_peft_model(hf_peft_model, peft_type, weights_path):
 
         def download_peft_weights():
             for ff_peft_config, peft_dict in self.pefts.items():
-                peft_config = peft_dict["peft_config"]
-                peft_type = peft_dict["peft_type"]
-                peft_model_id = ff_peft_config.peft_model_id
-
-                weights_path = get_weights_path(peft_model_id)
-                refresh_cache_if_needed(peft_model_id)
-                ff_revision, ff_revision_file, latest_revision = self.__get_revision_hashes(
-                    peft_model_id, weights_path
-                )
-
-                if ff_revision != latest_revision:
-                    print(
-                        f"'{peft_model_id}' local model weights need updating! Downloading/converting new weights now..."
+                if not ff_peft_config.init_lora_weights:
+                    peft_config = peft_dict["peft_config"]
+                    peft_type = peft_dict["peft_type"]
+                    peft_model_id = ff_peft_config.peft_model_id
+
+                    weights_path = get_weights_path(peft_model_id)
+                    refresh_cache_if_needed(peft_model_id)
+                    ff_revision, ff_revision_file, latest_revision = self.__get_revision_hashes(
+                        peft_model_id, weights_path
                     )
-                    hf_model = get_hf_llm(peft_model_id)
-                    hf_peft_model = PeftModel.from_pretrained(
-                        hf_model, peft_model_id, config=peft_config
-                    )
-                    # Convert the model to FlexFlow format
-                    convert_peft_model(hf_peft_model, peft_type, weights_path)
-                    # Save new revision hash to file
-                    with open(ff_revision_file, "w+") as f:
-                        f.write(latest_revision)
-                    print(f"Done converting the weights for model {peft_model_id}")
-                    # Deallocate hf model
-                    del hf_peft_model
-                    del hf_model
-                    gc.collect()
-                    torch.cuda.empty_cache()
+
+                    if ff_revision != latest_revision:
+                        print(
+                            f"'{peft_model_id}' local model weights need updating! Downloading/converting new weights now..."
+                        )
+                        hf_model = get_hf_llm(peft_model_id)
+                        hf_peft_model = PeftModel.from_pretrained(
+                            hf_model, peft_model_id, config=peft_config
+                        )
+                        # Convert the model to FlexFlow format
+                        convert_peft_model(hf_peft_model, peft_type, weights_path)
+                        # Save new revision hash to file
+                        with open(ff_revision_file, "w+") as f:
+                            f.write(latest_revision)
+                        print(f"Done converting the weights for model {peft_model_id}")
+                        # Deallocate hf model
+                        del hf_peft_model
+                        del hf_model
+                        gc.collect()
+                        torch.cuda.empty_cache()
 
         self.weights_path = get_weights_path(self.model_name)
         download_llm_weights()
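
Note (not part of the applied diff): with init_lora_weights=True, the
finetuning adapter's weights are created locally, so download_peft_weights()
skips the Hub download for it and the LoRA shape comes entirely from the
demo's config. A minimal sketch of the patched registration follows, assuming
the ff.LoRALinearConfig / llm.add_peft API used by the FlexFlow PEFT demos;
exact signatures may differ across versions, and the optimizer_kwargs value
shown is a hypothetical placeholder (the demo sets its own):

    import flexflow.serve as ff

    # Finetuning adapter initialized from scratch; no HuggingFace download.
    lora_ft_config = ff.LoRALinearConfig(
        llm.cache_path,                    # local weight cache directory
        "flechman/llama-3-8b-lora-dolly",  # id the trained adapter is saved under
        trainable=True,                    # weights are updated during finetuning
        init_lora_weights=True,            # fresh weights -> skip download_peft_weights()
        rank=16,                           # LoRA rank
        lora_alpha=16.0,                   # LoRA scaling factor
        target_modules=["down_proj"],      # modules that receive LoRA adapters
        base_model_name_or_path="meta-llama/Meta-Llama-3-8B",
        optimizer_type=ff.OptimizerType.OPTIMIZER_TYPE_SGD,
        optimizer_kwargs={"learning_rate": 1e-3},  # hypothetical; demo supplies its own
    )
    llm.add_peft(lora_ft_config)

The inference adapter (goliaro/llama-3-8b-lora) leaves init_lora_weights
False, so it still goes through the download/convert path in serve.py.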