Skip to content

Commit

Permalink
Merge pull request rui-ye#19 from rui-ye/training
Browse files (browse the repository at this point in the history)
[rui/training] implement QLoRA for better 4bit quantization
  • Loading branch information
rui-ye authored Mar 30, 2024
2 parents (d45644d + d9f8087), commit 859865a
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 4 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
training_scripts/*
!training_scripts/run_sft.sh
!training_scripts/run_dpo.sh

evaluation/open_ended/data/*/model_answer/*
evaluation/open_ended/data/*/model_judgment/*

Expand Down
15 changes: 13 additions & 2 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class FedArguments:
fedopt_eta: Optional[float] = field(default=1e-3, metadata={"help": "the global learning rate parameter of FedAdagrad, FedYogi and FedAdam"})
fedopt_beta1: Optional[float] = field(default=0.9, metadata={"help": "the beta1 parameter of FedYogi and FedAdam"})
fedopt_beta2: Optional[float] = field(default=0.99, metadata={"help": "the beta2 parameter of FedYogi and FedAdam"})
save_model_freq: Optional[int] = field(default=50, metadata={"help": "the frequency to save the model. 50 means save every 50 rounds"})

@dataclass
class ScriptArguments:
Expand Down Expand Up @@ -102,9 +103,19 @@ def get_training_args(script_args, new_lr):
def get_model_config(script_args):
if script_args.load_in_8bit and script_args.load_in_4bit:
raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
elif script_args.load_in_8bit or script_args.load_in_4bit:
elif script_args.load_in_8bit:
quantization_config = BitsAndBytesConfig(
load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
load_in_8bit=script_args.load_in_8bit
)
# Copy the model to each device
device_map = {"": Accelerator().local_process_index}
torch_dtype = torch.bfloat16
elif script_args.load_in_4bit:
quantization_config = BitsAndBytesConfig(
load_in_4bit=script_args.load_in_4bit,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.bfloat16,
)
# Copy the model to each device
device_map = {"": Accelerator().local_process_index}
Expand Down
7 changes: 6 additions & 1 deletion main_dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

model.config.use_cache = False # silence the warnings. Please re-enable for inference!

if training_args.gradient_checkpointing:
model.enable_input_require_grads()

# ===== Define the global and local models =====
global_dict = copy.deepcopy(get_peft_model_state_dict(model))
local_dict_list = [copy.deepcopy(global_dict) for i in range(fed_args.num_clients)]
Expand Down Expand Up @@ -115,7 +120,7 @@
set_peft_model_state_dict(model, global_dict) # update global model

# ===== Save the model =====
if (round+1) % 50 == 0:
if (round+1) % fed_args.save_model_freq == 0:
trainer.save_model(os.path.join(script_args.output_dir, f"checkpoint-{round+1}"))

np.save(os.path.join(script_args.output_dir, "training_loss.npy"), np.array(training_loss))
7 changes: 6 additions & 1 deletion main_sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

model.config.use_cache = False # silence the warnings. Please re-enable for inference!

if training_args.gradient_checkpointing:
model.enable_input_require_grads()

# ===== Define the global and local models =====
global_dict = copy.deepcopy(get_peft_model_state_dict(model))
local_dict_list = [copy.deepcopy(global_dict) for i in range(fed_args.num_clients)]
Expand Down Expand Up @@ -114,7 +119,7 @@
set_peft_model_state_dict(model, global_dict) # Update global model

# ===== Save the model =====
if (round+1) % 50 == 0:
if (round+1) % fed_args.save_model_freq == 0:
trainer.save_model(os.path.join(script_args.output_dir, f"checkpoint-{round+1}"))

np.save(os.path.join(script_args.output_dir, "training_loss.npy"), np.array(training_loss))

0 comments on commit 859865a

Please sign in to comment.