forked from rui-ye/OpenFedLLM
Commit
[rui/training] include instruction tuning and value alignment
- training of federated instruction tuning, i.e., main_sft.py
- training of federated value alignment, i.e., main_dpo.py
- federated learning algorithms in the federated_learning/ directory
- first version of the README
Showing 15 changed files with 969 additions and 6 deletions.
@@ -0,0 +1,137 @@
from dataclasses import dataclass, field, asdict
from typing import Optional
from transformers import HfArgumentParser, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig
import os
import json
from accelerate import Accelerator
import torch
from datetime import datetime, timedelta

# Define and parse arguments.
@dataclass
class FedArguments:
    fed_alg: Optional[str] = field(default="fedavg", metadata={"help": "the algorithm to use"})
    num_rounds: Optional[int] = field(default=500, metadata={"help": "the number of rounds"})
    num_clients: Optional[int] = field(default=2, metadata={"help": "the number of clients"})
    sample_clients: Optional[int] = field(default=2, metadata={"help": "the number of clients to sample"})
    split_strategy: Optional[str] = field(default="iid", metadata={"help": "the split strategy"})
    prox_mu: Optional[float] = field(default=0.01, metadata={"help": "the mu parameter of FedProx"})
    fedopt_tau: Optional[float] = field(default=1e-3, metadata={"help": "the tau parameter of FedAdagrad, FedYogi and FedAdam"})
    fedopt_eta: Optional[float] = field(default=1e-3, metadata={"help": "the global learning rate parameter of FedAdagrad, FedYogi and FedAdam"})
    fedopt_beta1: Optional[float] = field(default=0.9, metadata={"help": "the beta1 parameter of FedYogi and FedAdam"})
    fedopt_beta2: Optional[float] = field(default=0.99, metadata={"help": "the beta2 parameter of FedYogi and FedAdam"})

@dataclass
class ScriptArguments:
    model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(
        default="lucasmccabe-lmi/CodeAlpaca-20k", metadata={"help": "the dataset name"}
    )
    log_with: Optional[str] = field(default="none", metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=2e-5, metadata={"help": "the learning rate"})  # vicuna and alpaca use 2e-5
    batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"})
    seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "Whether to use PEFT or not to train adapters"})
    trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
    output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
    peft_lora_r: Optional[int] = field(default=8, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
    logging_steps: Optional[int] = field(default=100, metadata={"help": "the number of logging steps"})
    use_auth_token: Optional[bool] = field(default=False, metadata={"help": "Use HF auth token to access the model"})  # token and use_auth_token cannot be used together
    num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(default=10, metadata={"help": "the number of training steps"})
    save_steps: Optional[int] = field(
        default=1000, metadata={"help": "Number of update steps between two checkpoint saves"}
    )
    save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
    push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
    hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})
    gradient_checkpointing: Optional[bool] = field(default=True, metadata={"help": "Enable gradient checkpointing"})
    template: Optional[str] = field(default="alpaca", metadata={"help": "the template to use"})
    seed: Optional[int] = field(default=2023, metadata={"help": "the seed to use"})
    # auth_token_path: Optional[str] = field(default="utils/hf_token.yaml", metadata={"help": "the path to the auth token"})
    dpo_beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter of DPO"})
    dataset_sample: Optional[int] = field(default=20000, metadata={"help": "the number of samples to use from the dataset"})
    local_data_dir: Optional[str] = field(default=None, metadata={"help": "the local data directory if you want to use downloaded data"})

parser = HfArgumentParser((ScriptArguments, FedArguments))
script_args, fed_args = parser.parse_args_into_dataclasses()

# ===== Define the LoraConfig =====
if script_args.use_peft:
    peft_config = LoraConfig(
        r=script_args.peft_lora_r,
        lora_alpha=script_args.peft_lora_alpha,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
else:
    peft_config = None

def get_config():
    return script_args, fed_args, peft_config

# ===== Define the training arguments =====
def get_training_args(script_args, new_lr):
    training_args = TrainingArguments(
        output_dir=script_args.output_dir,
        per_device_train_batch_size=script_args.batch_size,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        learning_rate=new_lr,
        logging_steps=script_args.logging_steps,
        num_train_epochs=script_args.num_train_epochs,
        max_steps=script_args.max_steps,
        report_to=script_args.log_with,
        save_steps=script_args.save_steps,
        save_total_limit=script_args.save_total_limit,
        push_to_hub=script_args.push_to_hub,
        hub_model_id=script_args.hub_model_id,
        gradient_checkpointing=script_args.gradient_checkpointing,
        lr_scheduler_type="constant",
    )
    return training_args

def get_model_config(script_args):
    if script_args.load_in_8bit and script_args.load_in_4bit:
        raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
    elif script_args.load_in_8bit or script_args.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
        )
        # Copy the model to each device
        device_map = {"": Accelerator().local_process_index}
        torch_dtype = torch.bfloat16
    else:
        device_map = None
        quantization_config = None
        torch_dtype = None
    return device_map, quantization_config, torch_dtype

def save_config(script_args, fed_args):
    # Build a run-specific output directory named after the key hyperparameters
    now_time = datetime.now().strftime("%Y%m%d%H%M%S")
    dataset_name_split = os.path.basename(script_args.dataset_name)
    output_dir = f"{script_args.output_dir}/{dataset_name_split}_{script_args.dataset_sample}_{fed_args.fed_alg}_c{fed_args.num_clients}s{fed_args.sample_clients}_i{script_args.max_steps}_b{script_args.batch_size}a{script_args.gradient_accumulation_steps}_l{script_args.seq_length}_r{script_args.peft_lora_r}a{script_args.peft_lora_alpha}_{now_time}"
    while True:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)  # makedirs also creates the parent output directory if needed
            break
        else:
            # Name collision: bump the timestamp by one second and rebuild the same name format
            now_time = (datetime.now() + timedelta(seconds=1)).strftime("%Y%m%d%H%M%S")
            output_dir = f"{script_args.output_dir}/{dataset_name_split}_{script_args.dataset_sample}_{fed_args.fed_alg}_c{fed_args.num_clients}s{fed_args.sample_clients}_i{script_args.max_steps}_b{script_args.batch_size}a{script_args.gradient_accumulation_steps}_l{script_args.seq_length}_r{script_args.peft_lora_r}a{script_args.peft_lora_alpha}_{now_time}"

    script_args.output_dir = output_dir
    with open(os.path.join(script_args.output_dir, "args.json"), "w") as f:
        combined_dict = {
            "script_args": asdict(script_args),
            "fed_args": asdict(fed_args),
        }
        json.dump(combined_dict, f, indent=4)
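The helpers above are meant to be consumed by the training entry points (main_sft.py, main_dpo.py). A minimal sketch of that wiring, assuming the module is importable as `config` (the file name is not shown in this diff):

# Sketch of a driver script; the module name `config` is an assumption
from config import get_config, get_training_args, get_model_config, save_config

script_args, fed_args, peft_config = get_config()
save_config(script_args, fed_args)  # creates the timestamped run directory and writes args.json

device_map, quantization_config, torch_dtype = get_model_config(script_args)
# new_lr could be decayed across federated rounds; here we just pass the base rate
training_args = get_training_args(script_args, script_args.learning_rate)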
federated_learning/__init__.py
@@ -0,0 +1,5 @@
from .fed_local_sft import get_fed_local_sft_trainer, SCAFFOLD_Callback
from .fed_local_dpo import get_fed_local_dpo_trainer
from .fed_global import get_clients_this_round, global_aggregate
from .split_dataset import split_dataset, get_dataset_this_round
from .fed_utils import get_proxy_dict, get_auxiliary_dict
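These re-exports make the algorithm helpers importable from the package root. For instance (a hypothetical consumer at the repository root, assuming the package is on the import path):

# Hypothetical import from a training script
from federated_learning import (
    split_dataset,           # partition the dataset across clients
    get_dataset_this_round,  # select this round's local shard
    get_clients_this_round,
    global_aggregate,
)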
federated_learning/fed_global.py
@@ -0,0 +1,61 @@
import random
import torch

def get_clients_this_round(fed_args, round):
    if (fed_args.fed_alg).startswith('local'):
        # Local baseline 'local<k>': only client k trains; no cross-client aggregation
        clients_this_round = [int((fed_args.fed_alg)[-1])]
    else:
        if fed_args.num_clients < fed_args.sample_clients:
            # Fewer clients than the sample size: use all of them
            clients_this_round = list(range(fed_args.num_clients))
        else:
            # Seed with the round index so client sampling is reproducible
            random.seed(round)
            clients_this_round = sorted(random.sample(range(fed_args.num_clients), fed_args.sample_clients))
    return clients_this_round

def global_aggregate(fed_args, global_dict, local_dict_list, sample_num_list, clients_this_round, round_idx, proxy_dict=None, opt_proxy_dict=None, auxiliary_info=None):
    sample_this_round = sum([sample_num_list[client] for client in clients_this_round])
    global_auxiliary = None

    if fed_args.fed_alg == 'scaffold':
        # Sample-size-weighted average of local models, plus an update of the server control variate
        for key in global_dict.keys():
            global_dict[key] = sum([local_dict_list[client][key] * sample_num_list[client] / sample_this_round for client in clients_this_round])
        global_auxiliary, auxiliary_delta_dict = auxiliary_info
        for key in global_auxiliary.keys():
            delta_auxiliary = sum([auxiliary_delta_dict[client][key] for client in clients_this_round])
            global_auxiliary[key] += delta_auxiliary / fed_args.num_clients

    elif fed_args.fed_alg == 'fedavgm':
        # Momentum-based FedAvg: exponential moving average of the aggregated update
        for key in global_dict.keys():
            delta_w = sum([(local_dict_list[client][key] - global_dict[key]) * sample_num_list[client] / sample_this_round for client in clients_this_round])
            proxy_dict[key] = fed_args.fedopt_beta1 * proxy_dict[key] + (1 - fed_args.fedopt_beta1) * delta_w if round_idx > 0 else delta_w
            global_dict[key] = global_dict[key] + proxy_dict[key]

    elif fed_args.fed_alg == 'fedadagrad':
        for key, param in opt_proxy_dict.items():
            delta_w = sum([(local_dict_list[client][key] - global_dict[key]) for client in clients_this_round]) / len(clients_this_round)
            # In the paper 'Adaptive Federated Optimization', momentum is not used for FedAdagrad
            proxy_dict[key] = delta_w
            opt_proxy_dict[key] = param + torch.square(proxy_dict[key])
            global_dict[key] += fed_args.fedopt_eta * torch.div(proxy_dict[key], torch.sqrt(opt_proxy_dict[key]) + fed_args.fedopt_tau)

    elif fed_args.fed_alg == 'fedyogi':
        for key, param in opt_proxy_dict.items():
            delta_w = sum([(local_dict_list[client][key] - global_dict[key]) for client in clients_this_round]) / len(clients_this_round)
            proxy_dict[key] = fed_args.fedopt_beta1 * proxy_dict[key] + (1 - fed_args.fedopt_beta1) * delta_w if round_idx > 0 else delta_w
            delta_square = torch.square(proxy_dict[key])
            # Yogi second-moment update: v = v - (1 - beta2) * delta^2 * sign(v - delta^2)
            opt_proxy_dict[key] = param - (1 - fed_args.fedopt_beta2) * delta_square * torch.sign(param - delta_square)
            global_dict[key] += fed_args.fedopt_eta * torch.div(proxy_dict[key], torch.sqrt(opt_proxy_dict[key]) + fed_args.fedopt_tau)

    elif fed_args.fed_alg == 'fedadam':
        for key, param in opt_proxy_dict.items():
            delta_w = sum([(local_dict_list[client][key] - global_dict[key]) for client in clients_this_round]) / len(clients_this_round)
            proxy_dict[key] = fed_args.fedopt_beta1 * proxy_dict[key] + (1 - fed_args.fedopt_beta1) * delta_w if round_idx > 0 else delta_w
            # Adam second-moment update: v = beta2 * v + (1 - beta2) * delta^2
            opt_proxy_dict[key] = fed_args.fedopt_beta2 * param + (1 - fed_args.fedopt_beta2) * torch.square(proxy_dict[key])
            global_dict[key] += fed_args.fedopt_eta * torch.div(proxy_dict[key], torch.sqrt(opt_proxy_dict[key]) + fed_args.fedopt_tau)

    else:  # Normal dataset-size-based aggregation (e.g., FedAvg)
        for key in global_dict.keys():
            global_dict[key] = sum([local_dict_list[client][key] * sample_num_list[client] / sample_this_round for client in clients_this_round])

    return global_dict, global_auxiliary
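To show how get_clients_this_round and global_aggregate fit together, here is a minimal sketch of a FedAvg server loop. The per-client training step is stubbed out: train_locally and the state-dict plumbing are illustrative assumptions, not code from this commit.

import copy
from federated_learning import get_clients_this_round, global_aggregate

def run_fedavg(fed_args, global_dict, sample_num_list, train_locally):
    # train_locally(client_id, state_dict) -> new state_dict; hypothetical stand-in
    # for the per-client SFT/DPO training done by the trainers in this commit
    local_dict_list = [None] * fed_args.num_clients
    for round_idx in range(fed_args.num_rounds):
        clients = get_clients_this_round(fed_args, round_idx)
        for client in clients:
            local_dict_list[client] = train_locally(client, copy.deepcopy(global_dict))
        # With fed_alg='fedavg' this reduces to the sample-size-weighted average
        # global = sum_k (n_k / n) * local_k over the sampled clients
        global_dict, _ = global_aggregate(
            fed_args, global_dict, local_dict_list, sample_num_list,
            clients, round_idx,
        )
    return global_dict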