-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathevaluate.py
72 lines (54 loc) · 3.1 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import argparse
from train import TrainMeanField
parser = argparse.ArgumentParser()
parser.add_argument('--wandb_ids', default = ["7knl4x32"],type = str, help='wandb ids', nargs = "+")
parser.add_argument('--GPUs', default=["0"], type = str, help='Define Nb', nargs = "+")
parser.add_argument('--memory', default=0.92, type = float, help="GPU memory")
parser.add_argument('--expl', default=[0.], type = float, help="amount of exploration", nargs = "+")
parser.add_argument('--seeds', default=1, type = int, help="num_seeds")
parser.add_argument('--n_sampling_rounds', default=1000, type = int, help="Number of sampling rounds")
parser.add_argument('--sampling_modes', default=["temps"], type = str, help="eps or temps", nargs = "+")
parser.add_argument('--n_test_basis_states', default=400, type = int, help="eps or temps")
args = parser.parse_args()
# python evaluate.py --wandb_id hlmw6stl --GPU MIG-c69ed117-8436-51d1-b4db-183ea0228cd6 --exp 0. 0.01 0.05 --n_sampling_rounds 400 --sampling_modes temps
def good_wandb_ids():
fKL = {"f0cszhfv": {"n_diff_steps": 30, "T_start": 10} , "1838jcin": {"n_diff_steps": 60, "T_start": 10}, "hlmw6stl": {"n_diff_steps": 100, "T_start": 10}}
PPO = {"9lu3ahbm": {"n_diff_steps": 30, "T_start": 10}, "ts7dch2k": {"n_diff_steps": 30, "T_start": 4}}
REINFORCE = {"da75rnt": {"n_diff_steps": 10, "T_start": 10}}
# python evaluate.py --wandb_id bp0pthmf --GPU 7 --exp 0. --n_sampling_rounds 400 --sampling_modes temps
# python evaluate.py --wandb_id qkfzunur --GPU 4 --exp 0. --n_sampling_rounds 400 --n_test_basis_states 1200 --seeds 3 --sampling_modes temps
def _8x8_ablation():
# python evaluate.py --wandb_id 23n6g8f2 usjpghdv sywdxxhb --GPU MIG-c6766c68-2ea4-5e48-b9d4-f0d93f1beeed --exp 0. --n_sampling_rounds 1000 --seeds 3 --sampling_modes temps
pass
def evaluate():
devices = args.GPUs
device_str = ""
for idx, device in enumerate(devices):
if (idx != len(devices) - 1):
device_str += str(devices[idx]) + ","
else:
device_str += str(devices[idx])
print(device_str)
if (len(args.GPUs) > 1):
device_str = ""
for idx, device in enumerate(devices):
if (idx != len(devices) - 1):
device_str += str(devices[idx]) + ","
else:
device_str += str(devices[idx])
print(device_str, type(device_str))
else:
device_str = str(args.GPUs[0])
os.environ['CUDA_DEVICE_ORDER'] = "PCI_BUS_ID"
os.environ['CUDA_VISIBLE_DEVICES'] = device_str
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = str(args.memory)
#os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
config = {"jit": True}
for wandb_id in args.wandb_ids:
train = TrainMeanField(config, load_wandb_id=wandb_id)
train.wandb_old_run_id = wandb_id
for sampling_mode in args.sampling_modes:
train.test_ubiased_estimator(args.expl, args.seeds, args.n_sampling_rounds, sampling_mode, args.n_test_basis_states)
if(__name__ == "__main__"):
evaluate()