-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresnet_multicls.py
50 lines (41 loc) · 1.34 KB
/
resnet_multicls.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import sys
import argparse
import yaml
from func.trainer import Trainer, Train_Project
def dltrain(args):
    """Load a YAML training config, prepare data/model objects, and train.

    Args:
        args: Parsed command-line namespace providing:
            - ``config_yaml``: path to the YAML configuration file.
            - ``gpu_device``: ``None``, or a comma-separated string of
              integer GPU ids (e.g. ``"0,1"``).

    Side effects:
        Prints a message and exits the process with status 1 if any entry
        of ``gpu_device`` cannot be parsed as an integer.
    """
    config_path = args.config_yaml
    # Use a context manager so the config file handle is always closed.
    with open(config_path, "r") as config_file:
        config = yaml.safe_load(config_file)

    if args.gpu_device is not None:
        try:
            # Convert every comma-separated entry to its own integer id.
            # (The previous code used int(0), mapping every GPU to device 0.)
            # int() tolerates surrounding whitespace, so "0, 1" also works.
            gpu_device = [int(x) for x in args.gpu_device.split(",")]
        except ValueError:
            # int() raises ValueError (not TypeError) for non-numeric or
            # empty strings such as "a" or the trailing piece of "0,".
            print("Please provide a list of GPU devices in Integers")
            sys.exit(1)
    else:
        gpu_device = None

    train_prj = Train_Project(config)
    train_prj.load_train_objs()
    # prepare the dataloader
    train_prj.logger.info("generate training, validation datasets, and the model")
    train_prj.prepare_dataloader()
    # Debug dump of the project's instance attributes.
    print(vars(train_prj))
    trainer = Trainer(train_prj, gpu_device)
    trainer.train()
    print("Finished Training")
    # NOTE(review): a `cleanup()` call was previously commented out here;
    # confirm whether any teardown is required after training.
def main():
    """Parse command-line arguments and run training via ``dltrain``.

    Options:
        -c: path to the YAML configuration file (required).
        -g: optional comma-separated integer GPU ids, e.g. ``0,1``.
    """
    parser = argparse.ArgumentParser(
        description="Run training from a YAML configuration file."
    )

    reqdarg = parser.add_argument_group("required arguments")
    reqdarg.add_argument(
        "-c",
        dest="config_yaml",
        type=str,
        required=True,
        # dltrain opens this path as a file, so it is a file, not a directory.
        help="path to the config YAML file",
    )

    optarg = parser.add_argument_group("optional arguments")
    optarg.add_argument(
        "-g",
        dest="gpu_device",
        type=str,
        # dltrain splits this string on "," and converts each part with int().
        help="available GPU device(s) as comma-separated integer ids, e.g. 0,1",
    )

    args = parser.parse_args()
    dltrain(args)
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()