-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_torcs.py
88 lines (77 loc) · 3.41 KB
/
run_torcs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import argparse
import importlib
from env import torcs_envs as torcs
# Command-line interface. Arguments are parsed at import time into the
# module-level ``args`` namespace, which main() reads and passes on to the
# selected agent's ``init(env, args)``.
parser = argparse.ArgumentParser(description="TORCS")
parser.add_argument(
    "--seed", type=int, default=777, help="random seed for reproducibility")
parser.add_argument(
    "--algo", type=str, default="sac-lstm",
    # Only these algorithms have an environment set up in main(); rejecting
    # anything else here gives a usage error instead of a late exception.
    choices=("dqn", "sac", "sac-lstm"),
    help="choose an algorithm")
parser.add_argument(
    "--test", dest="test", action="store_true", help="test mode (no training)")
parser.add_argument(
    "--load-from", type=str, help="load the saved model and optimizer at the beginning")
parser.add_argument(
    "--on-render", dest="render", action="store_true", help="turn on rendering")
# Logging defaults to on (see set_defaults below). --log is kept for
# backward compatibility; --no-log is the only way to disable it.
parser.add_argument(
    "--log", dest="log", action="store_true", help="turn on logging (default)")
parser.add_argument(
    "--no-log", dest="log", action="store_false", help="turn off logging")
parser.add_argument(
    "--save-period", type=int, default=50, help="save model period")
parser.add_argument(
    "--episode-num", type=int, default=10000, help="total episode num")
parser.add_argument(
    "--max-episode-steps", type=int, default=10000, help="max episode step")
parser.add_argument(
    "--interim-test-num", type=int, default=1, help="interim test number")
parser.add_argument(
    "--relaunch-period", type=int, default=5, help="environment relaunch period")
parser.add_argument(
    "--test-period", type=int, default=100, help="test period")
# NOTE(review): main() hard-codes nstack per algorithm and never reads this
# option; presumably an agent consumes it via ``args`` -- verify.
parser.add_argument(
    "--num-stack", type=int, default=4, help="number of states to stack")
parser.add_argument(
    "--reward-type", type=str, default="extra_github", help="reward type")
parser.add_argument(
    "--track", type=str, default="none", help="track name")
parser.add_argument(
    "--use-state-filter", dest="state_filter", action="store_true", help="apply filter to observations")
parser.add_argument(
    "--use-action-filter", dest="action_filter", action="store_true", help="apply filter to actions")
# set_defaults runs after every add_argument, so it also updates the defaults
# of both actions that share the ``log`` dest.
parser.set_defaults(
    test=False,
    load_from=None,
    render=False,
    log=True,
    state_filter=False,
    action_filter=False,
)
args = parser.parse_args()
# Frame-stack depth passed to ContinuousEnv for each continuous-action
# algorithm. These are fixed here; the --num-stack option is not consulted.
_CONTINUOUS_NSTACK = {"sac": 4, "sac-lstm": 1}


def _make_env():
    """Create the TORCS environment that matches ``args.algo``.

    Reads the module-level ``args`` namespace.

    Returns:
        ``torcs.DiscretizedEnv`` for "dqn", or ``torcs.ContinuousEnv`` for
        "sac" / "sac-lstm".

    Raises:
        ValueError: if ``args.algo`` names an unsupported algorithm.
    """
    # Example filter weights (previous to recent). Each filter gets its own
    # list so the two are never the same object.
    state_filter = [1., 3., 10.] if args.state_filter else None
    action_filter = [1., 3., 10.] if args.action_filter else None
    common = dict(
        reward_type=args.reward_type,
        track=args.track,
        state_filter=state_filter,
    )
    if args.algo == "dqn":
        # NOTE(review): no action filter is passed to the discretized env,
        # even with --use-action-filter -- confirm this is intended.
        return torcs.DiscretizedEnv(
            nstack=1,
            action_filter=None,
            action_count=21,
            **common)
    if args.algo in _CONTINUOUS_NSTACK:
        return torcs.ContinuousEnv(
            nstack=_CONTINUOUS_NSTACK[args.algo],
            action_filter=action_filter,
            **common)
    # ValueError subclasses Exception, so existing broad handlers still work.
    raise ValueError("Invalid algorithm: {!r}".format(args.algo))


def main():
    """Build the environment, load the agent for ``args.algo`` and run it.

    The agent module ``torcs.<algo>`` is imported dynamically. Its
    ``init(env, args)`` must return an object with ``train()`` and ``test()``.
    ``test()`` runs when ``--test`` was given; otherwise ``train()`` runs.

    Raises:
        ValueError: if ``args.algo`` is not "dqn", "sac" or "sac-lstm".
    """
    env = _make_env()
    # This string names the agent package ``torcs``. It is unrelated to the
    # ``torcs`` alias for ``env.torcs_envs`` used by _make_env().
    module = importlib.import_module("torcs." + args.algo)
    agent = module.init(env, args)
    if args.test:
        agent.test()
    else:
        agent.train()
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()