train.py
# Adapted from Tevatron code
import logging
import sys

import torch
import wandb
from transformers import HfArgumentParser

from src.dataset import TrainTextImageDataset
from src.collator import TrainTextImageDataCollator
from src.arguments import ModelArguments, DataArguments, TrainingArguments
from src.model import MMEBModel
from src.trainer import GradCacheLateProcessTrainer
from src.utils import print_rank
from src.model_utils import load_processor, get_backbone_name

logger = logging.getLogger(__name__)

def main():
    # A hack for torch.distributed.launch, which passes --local-rank=<n>
    # while HfArgumentParser expects --local_rank <n>:
    # https://github.com/huggingface/transformers/issues/22171
    for arg in list(sys.argv):  # iterate over a copy, since sys.argv is mutated below
        if arg.startswith("--local-rank="):
            rank = arg.split("=")[1]
            sys.argv.remove(arg)
            sys.argv.append('--local_rank')
            sys.argv.append(rank)
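    # For illustration only (hypothetical invocation): torch.distributed.launch
    # may call this script as
    #   train.py --local-rank=3 --output_dir runs/exp1
    # and the loop above rewrites sys.argv into the form HfArgumentParser accepts:
    #   train.py --output_dir runs/exp1 --local_rank 3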

    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model_args: ModelArguments
    data_args: DataArguments
    training_args: TrainingArguments

    # Initialize wandb only on the main process (rank 0, or any single-process run).
    if 'wandb' in training_args.report_to:
        if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
            print_rank('init wandb')
            wandb.init(project=training_args.project_name, name=training_args.run_name, mode="online")

    # Build the model, then record the backbone name on both argument objects
    # so downstream components can branch on it.
    model = MMEBModel.build(model_args, training_args)
    model_backbone = get_backbone_name(hf_config=model.config)
    setattr(model_args, 'model_backbone', model_backbone)
    setattr(training_args, 'model_backbone', model_backbone)
    print_rank(f'model_backbone: {model_backbone}')

    processor = load_processor(model_args)
    setattr(model, 'processor', processor)

    train_dataset = TrainTextImageDataset(data_args, model_args)
    collator = TrainTextImageDataCollator(data_args, model_args, processor)

    trainer = GradCacheLateProcessTrainer(
        model=model,
        processing_class=processor,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=collator,
        max_length=data_args.max_len,
    )
    train_dataset.trainer = trainer

    trainer.train()
    trainer.save_model(training_args.output_dir)

    # Save the processor alongside the model weights, once, on the main process.
    if trainer.is_world_process_zero():
        processor.save_pretrained(training_args.output_dir)

if __name__ == "__main__":
    main()
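
# Usage sketch (hypothetical paths and flag values; the full set of flags is
# defined by the dataclasses in src/arguments.py, not shown here). A launch
# along these lines would parse into the three dataclasses above and start
# training across four GPUs:
#   torchrun --nproc_per_node=4 train.py \
#       --output_dir runs/exp1 \
#       --report_to wandb --project_name mmeb --run_name exp1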