forked from Victorwz/LongMem
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvalidate.py
287 lines (244 loc) · 9.5 KB
/
validate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Validate a model on one or across multiple GPUs and write the averaged metrics to a JSON file.
"""
# import unilm
import argparse
import logging
import math
import json
import os
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple
# We need to setup root logger before importing any fairseq libraries.
# The level is read from the LOGLEVEL environment variable (falling back to
# "INFO") and every record is written to stdout.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
# NOTE(review): the logger is named "fairseq_cli.train" although this script
# only runs validation -- presumably inherited from the training script;
# confirm before renaming, since log filters may rely on the name.
logger = logging.getLogger("fairseq_cli.train")
import numpy as np
import torch
from omegaconf import DictConfig, OmegaConf
from fairseq import checkpoint_utils, options, quantization_utils, tasks, utils
from fairseq.data import data_utils, iterators
from fairseq.data.plasma_utils import PlasmaStore
from fairseq.dataclass.configs import FairseqConfig
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.distributed import fsdp_enable_wrap, fsdp_wrap
from fairseq.distributed import utils as distributed_utils
from fairseq.file_io import PathManager
from fairseq.logging import meters, metrics, progress_bar
from fairseq.model_parallel.megatron_trainer import MegatronTrainer
from fairseq.trainer import Trainer
def main(cfg: FairseqConfig) -> None:
    """Build the task, model and trainer, then run repeated validation passes.

    For every ``epoch`` in ``range(cfg.optimization.max_epoch)`` the NumPy and
    torch seeds are set to ``epoch``, the ``"valid"`` split is (re)loaded and
    :func:`validate` is called once.  The ``"accuracy"`` and ``"ppl"`` entries
    of the returned stats (when present) are collected, averaged, written as
    JSON to ``cfg.model.result_path`` and printed.

    Args:
        cfg: The run configuration.  An ``argparse.Namespace`` is accepted as
            well and is converted with ``convert_namespace_to_omegaconf``.

    Returns:
        None.  The function returns early (after logging the exception) when
        asynchronous checkpoint writing is requested but ``iopath`` cannot be
        imported.

    Raises:
        AssertionError: If neither ``cfg.dataset.max_tokens`` nor
            ``cfg.dataset.batch_size`` is set, or if ``cfg.criterion`` is
            falsy.
    """
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)

    utils.import_user_module(cfg.common)

    if (
        distributed_utils.is_master(cfg.distributed_training)
        and "job_logging_cfg" in cfg
    ):
        # make hydra logging work with ddp
        # (see https://github.com/facebookresearch/hydra/issues/1126)
        # ``import logging`` alone does not load the ``logging.config``
        # submodule, so import ``dictConfig`` explicitly instead of relying on
        # some other package having imported it as a side effect.
        from logging.config import dictConfig

        dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    if cfg.common.log_file is not None:
        handler = logging.FileHandler(filename=cfg.common.log_file)
        logger.addHandler(handler)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)

    # Print args
    logger.info(cfg)

    if cfg.checkpoint.write_checkpoints_asynchronously:
        try:
            import iopath  # noqa: F401
        except ImportError:
            logging.exception(
                "Asynchronous checkpoint writing is specified but iopath is "
                "not installed: `pip install iopath`"
            )
            return

    # Setup task, e.g., translation, language modeling, etc.
    print(cfg.task)
    task = tasks.setup_task(cfg.task)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    if cfg.distributed_training.ddp_backend == "fully_sharded":
        with fsdp_enable_wrap(cfg.distributed_training):
            model = fsdp_wrap(task.build_model(cfg.model))
    else:
        print(cfg.model)
        model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)

    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    # Parameters flagged with a truthy ``expert`` attribute are reported
    # separately from the remaining ("shared") parameters.
    logger.info(
        "num. shared model params: {:,} (num. trained: {:,})".format(
            sum(
                p.numel() for p in model.parameters() if not getattr(p, "expert", False)
            ),
            sum(
                p.numel()
                for p in model.parameters()
                if not getattr(p, "expert", False) and p.requires_grad
            ),
        )
    )
    logger.info(
        "num. expert model params: {} (num. trained: {})".format(
            sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
            sum(
                p.numel()
                for p in model.parameters()
                if getattr(p, "expert", False) and p.requires_grad
            ),
        )
    )

    # The valid dataset itself is loaded inside the evaluation loop below,
    # i.e. AFTER building the model.
    data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)

    # (optionally) Configure quantization
    if cfg.common.quantization_config_path is not None:
        quantizer = quantization_utils.Quantizer(
            config_path=cfg.common.quantization_config_path,
            max_epoch=cfg.optimization.max_epoch,
            max_update=cfg.optimization.max_update,
        )
    else:
        quantizer = None

    # Build trainer
    if cfg.common.model_parallel_size == 1:
        trainer = Trainer(cfg, task, model, criterion, quantizer)
    else:
        trainer = MegatronTrainer(cfg, task, model, criterion)
    logger.info(
        "training on {} devices (GPUs/TPUs)".format(
            cfg.distributed_training.distributed_world_size
        )
    )
    logger.info(
        "max tokens per device = {} and max sentences per device = {}".format(
            cfg.dataset.max_tokens,
            cfg.dataset.batch_size,
        )
    )

    # Run one validation pass per "epoch"; the epoch index doubles as the
    # random seed so that each pass is reproducible.
    max_epoch = cfg.optimization.max_epoch
    acc_list = []
    ppl_list = []
    for epoch in range(max_epoch):
        np.random.seed(epoch)
        utils.set_torch_seed(epoch)
        task.load_dataset("valid", combine=False, epoch=epoch)
        stats = validate(cfg, trainer, task, epoch, "valid")
        if "accuracy" in stats:
            acc_list.append(stats["accuracy"])
        if "ppl" in stats:
            ppl_list.append(stats["ppl"])

    # Report the mean of each collected metric together with per-pass values.
    save_dict = {}
    if acc_list:
        save_dict["acc"] = sum(acc_list) / len(acc_list)
        save_dict["acc_detail"] = acc_list
    if ppl_list:
        save_dict["ppl"] = sum(ppl_list) / len(ppl_list)
        save_dict["ppl_detail"] = ppl_list

    # NOTE(review): there is no is_master guard here, so every process that
    # runs ``main`` writes the same file -- confirm this is intended for
    # distributed runs.
    with open(cfg.model.result_path, 'w') as fp:
        fp.write(json.dumps(save_dict, indent=1))
    print(save_dict)
def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.FairseqTask,
    epoch,
    subset: str,
) -> Dict[str, Any]:
    """Evaluate the model on one validation subset and return its stats.

    Args:
        cfg: Run configuration; the ``dataset``, ``common``, ``checkpoint``
            and ``distributed_training`` groups are read.
        trainer: Trainer providing the validation iterator and ``valid_step``.
        task: The task; its ``post_validate`` hook is called when it exists.
        epoch: Epoch index passed to ``trainer.begin_valid_epoch`` and the
            progress bar.
        subset: Name of the dataset split to evaluate (e.g. ``"valid"``).

    Returns:
        The smoothed metric values of this validation pass, augmented by
        :func:`get_valid_stats`.
    """
    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)

    trainer.begin_valid_epoch(epoch)

    # Initialize data iterator
    itr = trainer.get_valid_iterator(subset).next_epoch_itr(
        shuffle=False, set_dataset_epoch=False  # use a fixed valid set
    )
    if cfg.common.tpu:
        itr = utils.tpu_data_loader(itr)

    # Tensorboard / wandb sinks are only enabled on the master process.
    is_master = distributed_utils.is_master(cfg.distributed_training)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        epoch=epoch,
        prefix=f"valid on '{subset}' subset",
        tensorboard_logdir=(cfg.common.tensorboard_logdir if is_master else None),
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(cfg.common.wandb_project if is_master else None),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
        ),
    )

    # create a new root metrics aggregator so validation metrics
    # don't pollute other aggregators (e.g., train meters)
    max_valid_steps = cfg.dataset.max_valid_steps
    with metrics.aggregate(new_root=True) as agg:
        for i, sample in enumerate(progress):
            # ``i`` is zero-based, so ``>=`` stops after exactly
            # ``max_valid_steps`` batches (``>`` evaluated one batch too many).
            if max_valid_steps is not None and i >= max_valid_steps:
                break
            trainer.valid_step(sample)

    # log validation stats
    stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())

    if hasattr(task, "post_validate"):
        task.post_validate(trainer.get_model(), stats, agg)

    progress.print(stats, tag=subset, step=trainer.get_num_updates())
    return stats
def get_valid_stats(
    cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]
) -> Dict[str, Any]:
    """Add bookkeeping entries to ``stats`` in place and return it.

    ``stats["num_updates"]`` is always set from the trainer.  When
    ``checkpoint_utils.save_checkpoint`` carries a ``best`` attribute, a
    ``best_<metric>`` entry is added as well: the max (or min, depending on
    ``cfg.checkpoint.maximize_best_checkpoint_metric``) of that recorded best
    value and the current value of ``cfg.checkpoint.best_checkpoint_metric``.
    """
    stats["num_updates"] = trainer.get_num_updates()

    save_fn = checkpoint_utils.save_checkpoint
    if not hasattr(save_fn, "best"):
        # No best value has been recorded; nothing more to add.
        return stats

    metric_name = cfg.checkpoint.best_checkpoint_metric
    if cfg.checkpoint.maximize_best_checkpoint_metric:
        pick_best = max
    else:
        pick_best = min
    stats["best_{0}".format(metric_name)] = pick_best(
        save_fn.best, stats[metric_name]
    )
    return stats
def cli_main(
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None
) -> None:
    """Parse the command line and launch :func:`main` via ``call_main``.

    Args:
        modify_parser: Optional callback forwarded to
            ``options.parse_args_and_arch`` so callers can adjust the parser.
    """
    args = options.parse_args_and_arch(
        options.get_training_parser(), modify_parser=modify_parser
    )
    cfg = convert_namespace_to_omegaconf(args)

    if cfg.common.use_plasma_view:
        # Keep a reference to the store for the duration of the run; it is
        # not explicitly shut down afterwards.
        server = PlasmaStore(path=cfg.common.plasma_path)
        logger.info(
            f"Started plasma server pid {server.server.pid} {cfg.common.plasma_path}"
        )

    if not args.profile:
        distributed_utils.call_main(cfg, main)
        return

    # Profiling requested: run under the CUDA profiler with NVTX ranges emitted.
    with torch.cuda.profiler.profile(), torch.autograd.profiler.emit_nvtx():
        distributed_utils.call_main(cfg, main)
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    cli_main()