-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path__main__.py
50 lines (38 loc) · 1.3 KB
/
__main__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import hydra
import os
import sys
from typing import TYPE_CHECKING
if __name__ == '__main__':
    # Environment tweaks applied only when run as a script / via ``python -m``.
    # They sit above the package imports, presumably so they take effect before
    # JAX/XLA and MuJoCo are loaded by ``.main`` (TODO confirm).

    # Ask XLA for deterministic GPU ops. Any flags the caller already put in
    # XLA_FLAGS are preserved and placed after the new flag.
    existing_flags = os.environ.get('XLA_FLAGS', '')
    suffix = f' {existing_flags}' if existing_flags else ''
    os.environ['XLA_FLAGS'] = '--xla_gpu_deterministic_ops=true' + suffix

    # Use the 'egl' MuJoCo GL backend unless MUJOCO_GL was set explicitly.
    os.environ.setdefault('MUJOCO_GL', 'egl')
from .main import main
from .config.root import register_all_configs, get_extra_overrides
if TYPE_CHECKING:
from .config.root import RootConfig
def preprocess_argv() -> None:
    """Rewrite ``sys.argv`` in place with overrides derived from the user's overrides.

    Every ``key=value`` argument in ``sys.argv[1:]`` is collected into a
    ``{key: value}`` dict and passed to ``get_extra_overrides``. For each
    ``(key, value)`` pair it returns:

    * if ``key`` was already given on the command line, that argument is
      replaced in place with ``key=value``;
    * otherwise ``key=value`` is appended to ``sys.argv``.

    Arguments without ``=`` (for example flags) are left untouched. If a key
    occurs more than once, the last occurrence is the one reported and replaced.
    Only the first ``=`` separates key from value, so values that themselves
    contain ``=`` are preserved instead of being silently dropped.
    """
    overrides: dict[str, str] = {}
    override_idx: dict[str, int] = {}
    for i, arg in enumerate(sys.argv[1:], start=1):
        # partition splits on the first '=' only. The previous split('=')
        # raised ValueError for values containing '=' and lost the override.
        key, sep, value = arg.partition('=')
        if not sep:
            # Not a key=value override; nothing to record.
            continue
        override_idx[key] = i
        overrides[key] = value

    extra_overrides = get_extra_overrides(overrides)
    for key, value in extra_overrides.items():
        override = f'{key}={value}'
        if key in override_idx:
            # Replace the user's argument so the key is not passed twice.
            sys.argv[override_idx[key]] = override
        else:
            sys.argv.append(override)
if __name__ == '__main__':
    # Script entry point. Order matters: configs are registered and argv is
    # rewritten before Hydra reads the command line in wrapped_main().

    # Register the project's configs (presumably into Hydra's ConfigStore;
    # see .config.root — TODO confirm) so Hydra can compose them.
    register_all_configs()
    # Rewrite sys.argv with derived overrides before Hydra parses it.
    preprocess_argv()

    # Per Hydra docs, a relative config_path is resolved against the file that
    # declares the task function, i.e. this module.
    @hydra.main(config_path='../config', config_name='config', version_base=None)
    def wrapped_main(cfg: 'RootConfig') -> None:
        """Hydra task function: forward the composed config to ``main``."""
        # Declared here rather than decorating ``main`` directly, presumably so
        # Hydra resolves config_path/job name relative to this file (TODO confirm).
        main(cfg)

    # Called without arguments: Hydra parses sys.argv, composes the config,
    # and invokes wrapped_main with it.
    wrapped_main()