add mistralai/Mistral-7B-v0.1 to finetune workflow
harborn committed Jan 5, 2024
1 parent cd05f0e commit 9a077eb
Showing 3 changed files with 8 additions and 2 deletions.
.github/workflows/workflow_finetune.yml (6 additions, 1 deletion)
@@ -34,7 +34,7 @@ jobs:
 name: finetune test
 strategy:
   matrix:
-    model: [ EleutherAI/gpt-j-6b, meta-llama/Llama-2-7b-chat-hf, gpt2, bigscience/bloom-560m, facebook/opt-125m, mosaicml/mpt-7b-chat, huggyllama/llama-7b ]
+    model: [ EleutherAI/gpt-j-6b, meta-llama/Llama-2-7b-chat-hf, gpt2, bigscience/bloom-560m, facebook/opt-125m, mosaicml/mpt-7b-chat, huggyllama/llama-7b, mistralai/Mistral-7B-v0.1 ]
     isPR:
       - ${{inputs.ci_type == 'pr'}}

@@ -43,6 +43,7 @@ jobs:
 include:
   - { model: "EleutherAI/gpt-j-6b"}
   - { model: "meta-llama/Llama-2-7b-chat-hf"}
+  - { model: "mistralai/Mistral-7B-v0.1"}

runs-on: self-hosted

@@ -96,6 +97,10 @@ jobs:
result['General']["gpt_base_model"] = True
else:
result['General']["gpt_base_model"] = False
if "${{ matrix.model }}" == "mistralai/Mistral-7B-v0.1":
result['General']['lora_config']['target_modules'] = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", "lm_head",]
else:
result['General']['lora_config']['target_modules'] = None
if "${{ matrix.model }}" == "meta-llama/Llama-2-7b-chat-hf":
result['General']["config"]["use_auth_token"] = "${{ env.HF_ACCESS_TOKEN }}"
else:
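The added branch above wires Mistral-7B to an explicit list of LoRA target modules, while every other model in the matrix keeps target_modules unset. As a rough standalone sketch of what that CI step amounts to (assumptions: the workflow's `result` dict is the parsed finetune config, PyYAML is used to round-trip it, and the file path and `model` variable below are illustrative stand-ins for the inline script and `${{ matrix.model }}`):

import yaml

# Stand-in for the CI matrix value ${{ matrix.model }}.
model = "mistralai/Mistral-7B-v0.1"

# Assumed: the workflow patches a parsed copy of finetune/finetune.yaml.
with open("finetune/finetune.yaml") as f:
    result = yaml.safe_load(f)

# setdefault guards against a missing section; the CI script assumes it exists.
lora_config = result["General"].setdefault("lora_config", {})

if model == "mistralai/Mistral-7B-v0.1":
    # Apply LoRA to all attention and MLP projections plus lm_head.
    lora_config["target_modules"] = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj", "lm_head",
    ]
else:
    # Leave target_modules unset for the other models in the matrix.
    lora_config["target_modules"] = None

with open("finetune/finetune.yaml", "w") as f:
    yaml.dump(result, f)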
finetune/finetune.py (1 addition, 1 deletion)
@@ -111,7 +111,7 @@ def train_func(config: Dict[str, Any]):
     trainer = common.trainer.Trainer.registory.get("DefaultTrainer")(config = {
         "num_train_epochs": config["Training"]["epochs"],
         "max_train_step": config["Training"].get("max_train_steps", None),
-        "log_step": 1,
+        "log_step": config["General"].get("log_step", 10),
         "output": config["General"]["output_dir"],
         "dataprocesser": {
             "type": "GeneralProcesser",
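The replaced line above swaps the hard-coded log interval of 1 for a value read from the General section of the config, defaulting to 10 when log_step is absent. A small illustration of that lookup with a plain dict (not the actual trainer code; the 50 is just an example override):

config = {"General": {"output_dir": "/tmp/llm-ray/output"}}

# No log_step in the config: the trainer falls back to the default of 10.
print(config["General"].get("log_step", 10))  # -> 10

# With an explicit value (e.g. set in finetune.yaml), that value wins.
config["General"]["log_step"] = 50
print(config["General"].get("log_step", 10))  # -> 50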
finetune/finetune.yaml (1 addition, 0 deletions)
@@ -1,6 +1,7 @@
 General:
   base_model: EleutherAI/gpt-j-6b
   gpt_base_model: true
+  log_step: 10
   output_dir: /tmp/llm-ray/output
   checkpoint_dir: /tmp/llm-ray/checkpoint
   config:
