forked from mosaicml/examples
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path train.py
57 lines (41 loc) · 1.92 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Entrypoint that runs the Composer trainer on a provided YAML hparams file."""
import sys
import tempfile
import warnings
from typing import Type
from composer.loggers.logger import LogLevel
from composer.loggers.logger_hparams import WandBLoggerHparams
from composer.trainer import TrainerHparams
from composer.utils import dist
def warning_on_one_line(message: str, category: Type[Warning], filename: str, lineno: int, file=None, line=None):
    """Render a warning as a single line, for use as ``warnings.formatwarning``.

    The result looks like ``"<CategoryName>: <message> (source: <filename>:<lineno>)"``
    followed by a newline. ``file`` and ``line`` are accepted only so the
    signature matches what the ``warnings`` module passes; they are ignored.
    """
    # Adapted from https://stackoverflow.com/questions/26430861/make-pythons-warnings-warn-not-mention-itself
    location = "{}:{}".format(filename, lineno)
    return "{}: {} (source: {})\n".format(category.__name__, message, location)
def main() -> None:
    """Build a Composer trainer from CLI/YAML hparams, log the config, and train.

    Steps:
      * Install ``warning_on_one_line`` as the global warning formatter.
      * If no CLI arguments were given, substitute ``--help`` so usage is shown.
      * Parse ``TrainerHparams`` from ``sys.argv``.
      * Attach the full hparams dict as ``config`` on any WandB logger hparams.
      * On global rank 0, upload the hparams YAML as the file artifact
        ``<run_name>/hparams.yaml``.
      * On every local rank 0, print the hparams YAML to stdout.
      * Call ``trainer.fit()``.
    """
    warnings.formatwarning = warning_on_one_line

    if len(sys.argv) == 1:
        # No arguments at all: show usage instead of failing on missing hparams.
        sys.argv = [sys.argv[0], "--help"]

    hparams = TrainerHparams.create(cli_args=True)  # reads cli args from sys.argv

    # If using wandb, store the config inside the wandb run.
    for logger_hparams in hparams.loggers:
        if isinstance(logger_hparams, WandBLoggerHparams):
            logger_hparams.config = hparams.to_dict()

    trainer = hparams.initialize_object()

    # Only log the config artifact once (global rank 0), since it should be
    # the same on all ranks.
    if dist.get_global_rank() == 0:
        with tempfile.NamedTemporaryFile(mode="x+") as f:
            f.write(hparams.to_yaml())
            # The logger reads the file by path, not through this handle, so
            # the buffered write must reach the OS first; without this the
            # uploaded artifact can be empty or truncated.
            f.flush()
            # NOTE(review): the temp file is deleted when this `with` exits;
            # this assumes file_artifact consumes the file before returning
            # (not asynchronously) — confirm against the logger implementation.
            trainer.logger.file_artifact(LogLevel.FIT,
                                         artifact_name=f"{trainer.logger.run_name}/hparams.yaml",
                                         file_path=f.name,
                                         overwrite=True)

    # Print the config to the terminal on each local rank 0. (The artifact
    # upload above happens only on global rank 0.)
    if dist.get_local_rank() == 0:
        print("*" * 30)
        print("Config:")
        print(hparams.to_yaml())
        print("*" * 30)

    trainer.fit()
if __name__ == "__main__":
    # Run training only when this file is executed as a script, not on import.
    main()