[rui/training] include instruction tuning and value alignment
- training of federated instruction tuning, i.e., main_sft.py
- training of federated value alignment, i.e., main_dpo.py
- federated learning algorithms in the federated_learning/ directory
- first version of the README
rui-ye committed Feb 4, 2024
1 parent 515090e commit 9ee5d04
Showing 15 changed files with 969 additions and 6 deletions.
69 changes: 69 additions & 0 deletions README.md
@@ -2,8 +2,77 @@

OpenFedLLM is an open-source codebase for training large language models via federated learning, intended for research use.

OpenFedLLM includes the following key features:
- 7 federated learning algorithms (e.g., FedAvg, FedProx, SCAFFOLD, FedAvgM).
- 2 LLM training algorithms, including instruction tuning (i.e., SFT) and value alignment (i.e., DPO).
- 30+ evaluation metrics covering general capabilities, medical QA, financial QA, code generation, and math solving.

## Setup

Clone the repo and install the required packages.
```
git clone https://github.com/rui-ye/OpenFedLLM.git
cd OpenFedLLM
conda create -n fedllm python=3.10
conda activate fedllm
pip install -r requirements.txt
```

## Training

We provide training scripts under `training_scripts/`.

### Federated Instruction Tuning

The training script is in `training_scripts/run_sft.sh`.

```
CUDA_VISIBLE_DEVICES=1 python main_sft.py \
--model_name_or_path "meta-llama/Llama-2-7b-hf" \
--dataset_name "vicgalle/alpaca-gpt4" \
--dataset_sample 20000 \
--fed_alg "fedavg" \
--num_clients 20 \
--sample_clients 2 \
--max_steps 10 \
--num_rounds 200 \
--batch_size 16 \
--gradient_accumulation_steps 1 \
--seq_length 512 \
--peft_lora_r 32 \
--peft_lora_alpha 64 \
--use_peft \
--load_in_8bit \
--output_dir "./output" \
--template "alpaca" \
```

Key arguments:

- `model_name_or_path`: the name or local path of the base model
- `template`: the template used to format chat prompts; define your own in `utils/template.py` (a sketch follows this list)
- `dataset_name`: the name of the dataset; modify `utils/process_dataset.py` if the dataset you are interested in is not yet supported
- `dataset_sample`: set this to sample a specific number of examples from the original dataset
- `fed_alg`: the name of the federated learning algorithm
- `num_clients`/`sample_clients`: `num_clients` clients in total, of which `sample_clients` are sampled each round
- `max_steps`: the number of model update steps per client in each round
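
Since the `template` argument controls how raw samples are rendered into prompts, here is a minimal sketch of what an entry in `utils/template.py` might look like. This is an illustration only: the actual schema in `utils/template.py` may differ, and the field names (`system`, `prompt`) and the helper `render_prompt` are hypothetical.

```
# Hypothetical sketch; see utils/template.py for the real schema.
TEMPLATES = {
    "alpaca": {
        # System preamble prepended to every prompt.
        "system": "Below is an instruction that describes a task. "
                  "Write a response that appropriately completes the request.",
        # Format string filled in with each sample's instruction.
        "prompt": "### Instruction:\n{instruction}\n\n### Response:\n",
    },
}

def render_prompt(template_name: str, instruction: str) -> str:
    """Render one training sample with the chosen template."""
    t = TEMPLATES[template_name]
    return t["system"] + "\n\n" + t["prompt"].format(instruction=instruction)
```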

### Federated Value Alignment

The training script is in `training_scripts/run_dpo.sh`.

```
python main_dpo.py --template "vicuna_v1.1"
```

Note that the main difference in usage between `main_sft.py` and `main_dpo.py` lies in the `template` argument; we plan to make them consistent in the future. A fuller example invocation follows the list below.
- For SFT, templates are defined in `utils/template.py`
- For DPO, templates are defined in `utils/conversation.py`
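
Since `config.py` (included in this commit, see below) defines the shared arguments, `main_dpo.py` should accept the same federated options as the SFT script, plus `dpo_beta`. A fuller invocation might look like the following; this is an untested sketch that only uses flags defined in `config.py`, assuming `main_dpo.py` parses them through the same `HfArgumentParser`.

```
CUDA_VISIBLE_DEVICES=1 python main_dpo.py \
 --model_name_or_path "meta-llama/Llama-2-7b-hf" \
 --fed_alg "fedavg" \
 --num_clients 20 \
 --sample_clients 2 \
 --num_rounds 200 \
 --dpo_beta 0.1 \
 --use_peft \
 --output_dir "./output" \
 --template "vicuna_v1.1"
```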

## Evaluation

Evaluation code lives in the `evaluation/` directory. Most of our evaluations follow existing, influential open-source repositories. Please refer to each sub-directory for the corresponding detailed README and running scripts.

For example, `evaluation/open_ended/` includes open-ended evaluations on three benchmarks: MT-Bench, Vicuna Bench, and AdvBench; see [README.md](evaluation/open_ended/README.md).

## Citation
137 changes: 137 additions & 0 deletions config.py
@@ -0,0 +1,137 @@
from dataclasses import dataclass, field, asdict
from typing import Optional
from transformers import HfArgumentParser, TrainingArguments, BitsAndBytesConfig
from peft import LoraConfig
import os
import json
from accelerate import Accelerator
import torch
from datetime import datetime, timedelta


# Define and parse arguments.
@dataclass
class FedArguments:
    fed_alg: Optional[str] = field(default="fedavg", metadata={"help": "the algorithm to use"})
    num_rounds: Optional[int] = field(default=500, metadata={"help": "the number of rounds"})
    num_clients: Optional[int] = field(default=2, metadata={"help": "the number of clients"})
    sample_clients: Optional[int] = field(default=2, metadata={"help": "the number of clients to sample"})
    split_strategy: Optional[str] = field(default="iid", metadata={"help": "the split strategy"})
    prox_mu: Optional[float] = field(default=0.01, metadata={"help": "the mu parameter of FedProx"})
    fedopt_tau: Optional[float] = field(default=1e-3, metadata={"help": "the tau parameter of FedAdagrad, FedYogi and FedAdam"})
    fedopt_eta: Optional[float] = field(default=1e-3, metadata={"help": "the global learning rate parameter of FedAdagrad, FedYogi and FedAdam"})
    fedopt_beta1: Optional[float] = field(default=0.9, metadata={"help": "the beta1 parameter of FedYogi and FedAdam"})
    fedopt_beta2: Optional[float] = field(default=0.99, metadata={"help": "the beta2 parameter of FedYogi and FedAdam"})

@dataclass
class ScriptArguments:

    model_name_or_path: Optional[str] = field(default="meta-llama/Llama-2-7b-hf", metadata={"help": "the model name"})
    dataset_name: Optional[str] = field(
        default="lucasmccabe-lmi/CodeAlpaca-20k", metadata={"help": "the dataset name"}
    )
    log_with: Optional[str] = field(default="none", metadata={"help": "use 'wandb' to log with wandb"})
    learning_rate: Optional[float] = field(default=2e-5, metadata={"help": "the learning rate"})  # vicuna and alpaca use 2e-5
    batch_size: Optional[int] = field(default=16, metadata={"help": "the batch size"})
    seq_length: Optional[int] = field(default=512, metadata={"help": "Input sequence length"})
    gradient_accumulation_steps: Optional[int] = field(
        default=1, metadata={"help": "the number of gradient accumulation steps"}
    )
    load_in_8bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 8 bits precision"})
    load_in_4bit: Optional[bool] = field(default=False, metadata={"help": "load the model in 4 bits precision"})
    use_peft: Optional[bool] = field(default=False, metadata={"help": "Whether to use PEFT or not to train adapters"})
    trust_remote_code: Optional[bool] = field(default=False, metadata={"help": "Enable `trust_remote_code`"})
    output_dir: Optional[str] = field(default="output", metadata={"help": "the output directory"})
    peft_lora_r: Optional[int] = field(default=8, metadata={"help": "the r parameter of the LoRA adapters"})
    peft_lora_alpha: Optional[int] = field(default=16, metadata={"help": "the alpha parameter of the LoRA adapters"})
    logging_steps: Optional[int] = field(default=100, metadata={"help": "the number of logging steps"})
    use_auth_token: Optional[bool] = field(default=False, metadata={"help": "Use HF auth token to access the model"})  # token and use_auth_token cannot be used together
    num_train_epochs: Optional[int] = field(default=3, metadata={"help": "the number of training epochs"})
    max_steps: Optional[int] = field(default=10, metadata={"help": "the number of training steps"})
    save_steps: Optional[int] = field(
        default=1000, metadata={"help": "Number of update steps between two checkpoint saves"}
    )
    save_total_limit: Optional[int] = field(default=10, metadata={"help": "Limits total number of checkpoints."})
    push_to_hub: Optional[bool] = field(default=False, metadata={"help": "Push the model to HF Hub"})
    hub_model_id: Optional[str] = field(default=None, metadata={"help": "The name of the model on HF Hub"})
    gradient_checkpointing: Optional[bool] = field(default=True, metadata={"help": "Enable gradient checkpointing"})
    template: Optional[str] = field(default="alpaca", metadata={"help": "the template to use"})
    seed: Optional[int] = field(default=2023, metadata={"help": "the seed to use"})
    # auth_token_path: Optional[str] = field(default="utils/hf_token.yaml", metadata={"help": "the path to the auth token"})
    dpo_beta: Optional[float] = field(default=0.1, metadata={"help": "the beta parameter of DPO"})
    dataset_sample: Optional[int] = field(default=20000, metadata={"help": "the number of samples to use from the dataset"})
    local_data_dir: Optional[str] = field(default=None, metadata={"help": "the local data directory if you want to use downloaded data"})

parser = HfArgumentParser((ScriptArguments, FedArguments))
script_args, fed_args = parser.parse_args_into_dataclasses()

# ===== Define the LoraConfig =====
if script_args.use_peft:
    peft_config = LoraConfig(
        r=script_args.peft_lora_r,
        lora_alpha=script_args.peft_lora_alpha,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
else:
    peft_config = None

def get_config():
    return script_args, fed_args, peft_config

# ===== Define the training arguments =====
def get_training_args(script_args, new_lr):
    training_args = TrainingArguments(
        output_dir=script_args.output_dir,
        per_device_train_batch_size=script_args.batch_size,
        gradient_accumulation_steps=script_args.gradient_accumulation_steps,
        learning_rate=new_lr,
        logging_steps=script_args.logging_steps,
        num_train_epochs=script_args.num_train_epochs,
        max_steps=script_args.max_steps,
        report_to=script_args.log_with,
        save_steps=script_args.save_steps,
        save_total_limit=script_args.save_total_limit,
        push_to_hub=script_args.push_to_hub,
        hub_model_id=script_args.hub_model_id,
        gradient_checkpointing=script_args.gradient_checkpointing,
        lr_scheduler_type="constant",
    )
    return training_args

def get_model_config(script_args):
    if script_args.load_in_8bit and script_args.load_in_4bit:
        raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
    elif script_args.load_in_8bit or script_args.load_in_4bit:
        quantization_config = BitsAndBytesConfig(
            load_in_8bit=script_args.load_in_8bit, load_in_4bit=script_args.load_in_4bit
        )
        # Copy the model to each device
        device_map = {"": Accelerator().local_process_index}
        torch_dtype = torch.bfloat16
    else:
        device_map = None
        quantization_config = None
        torch_dtype = None
    return device_map, quantization_config, torch_dtype

def save_config(script_args, fed_args):
    # Build a timestamped output directory whose name encodes the run setup.
    now_time = (datetime.now()).strftime("%Y%m%d%H%M%S")
    dataset_name_split = os.path.basename(script_args.dataset_name)
    output_dir = f"{script_args.output_dir}/{dataset_name_split}_{script_args.dataset_sample}_{fed_args.fed_alg}_c{fed_args.num_clients}s{fed_args.sample_clients}_i{script_args.max_steps}_b{script_args.batch_size}a{script_args.gradient_accumulation_steps}_l{script_args.seq_length}_r{script_args.peft_lora_r}a{script_args.peft_lora_alpha}_{now_time}"
    while True:
        if not os.path.exists(output_dir):
            os.mkdir(output_dir)
            break
        else:
            # Name collision: bump the timestamp by one second and retry.
            now_time = (datetime.now() + timedelta(seconds=1)).strftime("%Y%m%d%H%M%S")
            output_dir = f"{script_args.output_dir}/{dataset_name_split}_{fed_args.fed_alg}_c{fed_args.num_clients}s{fed_args.sample_clients}_i{script_args.max_steps}_b{script_args.batch_size}a{script_args.gradient_accumulation_steps}_l{script_args.seq_length}_{now_time}"

    script_args.output_dir = output_dir
    with open(os.path.join(script_args.output_dir, "args.json"), "w") as f:
        combined_dict = {
            "script_args": asdict(script_args),
            "fed_args": asdict(fed_args),
        }
        json.dump(combined_dict, f, indent=4)
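
For orientation, a training entry point would consume the helpers above roughly as follows. This is a minimal sketch; the real wiring lives in `main_sft.py`/`main_dpo.py`, which are part of this commit but not shown in this excerpt.

```
# Minimal usage sketch of config.py's helpers (hypothetical driver code).
from config import get_config, get_training_args, get_model_config, save_config

script_args, fed_args, peft_config = get_config()
save_config(script_args, fed_args)  # creates a timestamped output dir and dumps args.json

device_map, quantization_config, torch_dtype = get_model_config(script_args)
training_args = get_training_args(script_args, new_lr=script_args.learning_rate)
```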
36 changes: 34 additions & 2 deletions evaluation/open_ended/README.md
@@ -1,6 +1,11 @@
# Open-Ended LLM Judgement

- (Optional) You can first run `utils/merge_lora.py` to merge the LoRA adapter into the base model.
- We currently support three benchmarks:
  - MT-Bench
  - Vicuna Bench
  - AdvBench

## MT-Bench

@@ -84,4 +89,31 @@ The judgments will be saved to `data/advbench/model_judgment/[JUDGER]_[MODEL-ID]

```
python show_results_vicuna.py --eval_list [EVAL-LIST-ID]
```

## Citation

For MT-Bench and Vicuna Benchmark:
```
@misc{zheng2023judging,
title={Judging LLM-as-a-judge with MT-Bench and Chatbot Arena},
author={Lianmin Zheng and Wei-Lin Chiang and Ying Sheng and Siyuan Zhuang and Zhanghao Wu and Yonghao Zhuang and Zi Lin and Zhuohan Li and Dacheng Li and Eric P. Xing and Hao Zhang and Joseph E. Gonzalez and Ion Stoica},
year={2023},
eprint={2306.05685},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```

For AdvBench:
```
@misc{zou2023universal,
title={Universal and Transferable Adversarial Attacks on Aligned Language Models},
author={Andy Zou and Zifan Wang and J. Zico Kolter and Matt Fredrikson},
year={2023},
eprint={2307.15043},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
```

5 changes: 5 additions & 0 deletions federated_learning/__init__.py
@@ -0,0 +1,5 @@
from .fed_local_sft import get_fed_local_sft_trainer, SCAFFOLD_Callback
from .fed_local_dpo import get_fed_local_dpo_trainer
from .fed_global import get_clients_this_round, global_aggregate
from .split_dataset import split_dataset, get_dataset_this_round
from .fed_utils import get_proxy_dict, get_auxiliary_dict
61 changes: 61 additions & 0 deletions federated_learning/fed_global.py
@@ -0,0 +1,61 @@
import random
import torch

def get_clients_this_round(fed_args, round):
    if fed_args.fed_alg.startswith('local'):
        # Single-client baseline, e.g. `local0` always trains client 0 alone.
        clients_this_round = [int(fed_args.fed_alg[-1])]
    else:
        if fed_args.num_clients < fed_args.sample_clients:
            clients_this_round = list(range(fed_args.num_clients))
        else:
            # Seed with the round index so client sampling is reproducible.
            random.seed(round)
            clients_this_round = sorted(random.sample(range(fed_args.num_clients), fed_args.sample_clients))
    return clients_this_round

def global_aggregate(fed_args, global_dict, local_dict_list, sample_num_list, clients_this_round, round_idx, proxy_dict=None, opt_proxy_dict=None, auxiliary_info=None):
    sample_this_round = sum([sample_num_list[client] for client in clients_this_round])
    global_auxiliary = None

    if fed_args.fed_alg == 'scaffold':
        # Sample-weighted average of client models, plus an update of the
        # global control variate from the clients' auxiliary deltas.
        for key in global_dict.keys():
            global_dict[key] = sum([local_dict_list[client][key] * sample_num_list[client] / sample_this_round for client in clients_this_round])
        global_auxiliary, auxiliary_delta_dict = auxiliary_info
        for key in global_auxiliary.keys():
            delta_auxiliary = sum([auxiliary_delta_dict[client][key] for client in clients_this_round])
            global_auxiliary[key] += delta_auxiliary / fed_args.num_clients

    elif fed_args.fed_alg == 'fedavgm':
        # Momentum-based FedAvg
        for key in global_dict.keys():
            delta_w = sum([(local_dict_list[client][key] - global_dict[key]) * sample_num_list[client] / sample_this_round for client in clients_this_round])
            proxy_dict[key] = (fed_args.fedopt_beta1 * proxy_dict[key] + (1 - fed_args.fedopt_beta1) * delta_w) if round_idx > 0 else delta_w
            global_dict[key] = global_dict[key] + proxy_dict[key]

    elif fed_args.fed_alg == 'fedadagrad':
        for key, param in opt_proxy_dict.items():
            delta_w = sum([(local_dict_list[client][key] - global_dict[key]) for client in clients_this_round]) / len(clients_this_round)
            # In paper 'adaptive federated optimization', momentum is not used
            proxy_dict[key] = delta_w
            opt_proxy_dict[key] = param + torch.square(proxy_dict[key])
            global_dict[key] += fed_args.fedopt_eta * torch.div(proxy_dict[key], torch.sqrt(opt_proxy_dict[key]) + fed_args.fedopt_tau)

    elif fed_args.fed_alg == 'fedyogi':
        for key, param in opt_proxy_dict.items():
            delta_w = sum([(local_dict_list[client][key] - global_dict[key]) for client in clients_this_round]) / len(clients_this_round)
            proxy_dict[key] = (fed_args.fedopt_beta1 * proxy_dict[key] + (1 - fed_args.fedopt_beta1) * delta_w) if round_idx > 0 else delta_w
            delta_square = torch.square(proxy_dict[key])
            opt_proxy_dict[key] = param - (1 - fed_args.fedopt_beta2) * delta_square * torch.sign(param - delta_square)
            global_dict[key] += fed_args.fedopt_eta * torch.div(proxy_dict[key], torch.sqrt(opt_proxy_dict[key]) + fed_args.fedopt_tau)

    elif fed_args.fed_alg == 'fedadam':
        for key, param in opt_proxy_dict.items():
            delta_w = sum([(local_dict_list[client][key] - global_dict[key]) for client in clients_this_round]) / len(clients_this_round)
            proxy_dict[key] = (fed_args.fedopt_beta1 * proxy_dict[key] + (1 - fed_args.fedopt_beta1) * delta_w) if round_idx > 0 else delta_w
            opt_proxy_dict[key] = fed_args.fedopt_beta2 * param + (1 - fed_args.fedopt_beta2) * torch.square(proxy_dict[key])
            global_dict[key] += fed_args.fedopt_eta * torch.div(proxy_dict[key], torch.sqrt(opt_proxy_dict[key]) + fed_args.fedopt_tau)

    else:  # Normal dataset-size-based aggregation (FedAvg and friends)
        for key in global_dict.keys():
            global_dict[key] = sum([local_dict_list[client][key] * sample_num_list[client] / sample_this_round for client in clients_this_round])

    return global_dict, global_auxiliary
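
As a sanity check, the default branch of `global_aggregate` (hit by `fedavg`, among others) is a sample-weighted mean of the client states. A toy sketch, using a bare namespace in place of the parsed `fed_args`:

```
# Toy check of the default dataset-size-weighted branch.
from types import SimpleNamespace
import torch
from federated_learning.fed_global import global_aggregate

fed_args = SimpleNamespace(fed_alg="fedavg")
global_dict = {"w": torch.zeros(2)}
local_dict_list = [{"w": torch.tensor([1.0, 1.0])}, {"w": torch.tensor([3.0, 3.0])}]
sample_num_list = [100, 300]  # client 1 holds three times more data

new_global, _ = global_aggregate(
    fed_args, global_dict, local_dict_list, sample_num_list,
    clients_this_round=[0, 1], round_idx=0,
)
print(new_global["w"])  # tensor([2.5000, 2.5000]) = (1*100 + 3*300) / 400
```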