-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathBFS.py
123 lines (109 loc) · 2.96 KB
/
BFS.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import TNAS.Train_tester
import time
from nas_201_api import NASBench201API as API
# Operation tree design.
# The five candidate operations are organised as a binary tree with 9 nodes;
# internal nodes stand for subsets of operations and are split one level at a
# time during the search (see BFS_T_o).
#   node_ops[i] : operations represented by tree node i
#   e[i]        : children of tree node i ([] for leaves)
#   par[i]      : parent of tree node i (-1 for the root)
# Each inner list is created separately: ``[[]] * 9`` would make every slot
# that is not reassigned below (e[1], e[5]..e[8]) alias ONE shared list, so
# mutating one leaf's children would silently change the others too.
node_ops = [[] for _ in range(9)]  # operations on id-th node
e = [[] for _ in range(9)]  # edge list (children of each tree node)
par = [-1] * 9  # parent list (ints are immutable, so repetition is safe)
e[0] = [1, 2]
e[2] = [3, 4]
e[3] = [5, 6]
e[4] = [7, 8]
par[1] = 0
par[2] = 0
par[3] = 2
par[4] = 2
par[5] = 3
par[6] = 3
par[7] = 4
par[8] = 4
node_ops[0] = ['none', 'nor_conv_1x1', 'nor_conv_3x3', 'skip_connect', 'avg_pool_3x3']
node_ops[1] = ['none']
node_ops[2] = ['nor_conv_1x1', 'nor_conv_3x3', 'skip_connect', 'avg_pool_3x3']
node_ops[3] = ['nor_conv_1x1', 'nor_conv_3x3']
node_ops[4] = ['skip_connect', 'avg_pool_3x3']
node_ops[5] = ['nor_conv_1x1']
node_ops[6] = ['nor_conv_3x3']
node_ops[7] = ['skip_connect']
node_ops[8] = ['avg_pool_3x3']
# Cell's DAG edge list: L[i] = (source node, target node) of cell edge i.
# Node 0 is treated as the cell input and node 3 as the cell output.
L = [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
# Encoding candidates and operations
id_op = ['none', 'nor_conv_1x1', 'nor_conv_3x3', 'skip_connect', 'avg_pool_3x3']
masks = []  # expansion masks, one bit per cell edge; filled by Build_masks()
def To_mask(num, width=6):
    """Return ``num`` as a binary string, most-significant bit first.

    The result is left-padded with '0' to at least ``width`` characters
    (default 6, matching the six cell edges); values needing more bits keep
    all of them, e.g. ``To_mask(5) == '000101'`` and
    ``To_mask(64) == '1000000'``.

    Args:
        num: non-negative integer to encode.
        width: minimum length of the returned string.

    Raises:
        ValueError: if ``num`` is negative (previously such input silently
            produced an all-zero mask).
    """
    if num < 0:
        raise ValueError(f"mask index must be non-negative, got {num}")
    # Integer formatting replaces the old ``int(num / 2)`` loop, whose float
    # division loses low bits once ``num`` exceeds 2**53.
    return format(num, "b").zfill(width)
def Build_masks():
    """Fill the module-level ``masks`` list with every 6-bit mask string.

    After the call, ``masks[i] == To_mask(i)`` for ``i`` in ``0 .. 63`` (one
    bit per cell edge). The list object is rebuilt in place, so other code
    holding a reference to ``masks`` sees the new contents. Calling this more
    than once no longer appends duplicate masks, which made the search
    re-score identical candidates.
    """
    n_bits = 6  # one bit per cell edge
    masks[:] = [To_mask(i) for i in range(2 ** n_bits)]
def Build_cand_arch(Arch, mask):
    """Expand each edge's operation-tree node by one level according to ``mask``.

    For every position ``i`` of ``mask`` (one per cell edge), the tree node
    ``Arch[i]`` is looked up in the global child list ``e``:

    * nodes with fewer than two children are copied through unchanged;
    * otherwise the character ``mask[i]`` selects the child -- ``'1'`` picks
      the second child, anything else the first.

    Only the first ``len(mask)`` entries of ``Arch`` are used. A new list is
    returned; ``Arch`` itself is not modified.
    """
    def _choose(node, bit):
        # Leaf (or single-child) nodes cannot be split further.
        kids = e[node]
        if len(kids) < 2:
            return node
        return kids[1] if bit == '1' else kids[0]

    return [_choose(Arch[pos], bit) for pos, bit in enumerate(mask)]
def Score(Arch):
    """Return the metric for candidate architecture ``Arch``.

    Delegates to ``TNAS.Train_tester.get_metric``; what the metric measures
    is defined there (presumably a benchmark accuracy -- confirm).
    """
    get_metric = TNAS.Train_tester.get_metric
    return get_metric(Arch)
def Check_connected(Arch, edges=None, none_node=1):
    """Return True if the cell's output node is reachable from its input.

    ``Arch[i]`` is the operation-tree node assigned to cell edge
    ``edges[i]`` (``edges`` defaults to the module-level edge list ``L``).
    An edge whose tree node equals ``none_node`` (tree node 1, whose only
    operation is ``'none'``) carries no signal and is ignored. Node 0 is the
    input; the highest-numbered node appearing in ``edges`` is the output.

    ``edges`` must list every edge into a node before any edge out of it
    (true for ``L``), so one forward pass propagates reachability.

    Args:
        Arch: sequence of tree-node ids, at least ``len(edges)`` long.
        edges: list of ``(source, target)`` node pairs; defaults to ``L``.
        none_node: tree-node id that marks an edge as absent.
    """
    if edges is None:
        edges = L
    n_nodes = 1 + max(dst for _, dst in edges)
    reachable = [False] * n_nodes
    reachable[0] = True
    for i, (src, dst) in enumerate(edges):
        if Arch[i] == none_node:
            continue
        # OR, not overwrite: an active edge from an unreachable node must not
        # erase a path to ``dst`` that an earlier edge already established.
        if reachable[src]:
            reachable[dst] = True
    return reachable[n_nodes - 1]
def BFS_T_o():
    """Greedy breadth-first search over the operation tree.

    Every cell edge starts at the tree root (node 0). Tree nodes are visited
    in BFS order; whenever a visited node has children, every edge whose
    current tree node can still be split is expanded one level, trying all
    expansion masks. The all-first-child expansion is scored as the
    baseline, then each distinct, connected candidate is scored; a candidate
    replaces the current best only on a strictly higher score, so ties keep
    the earlier one. The best expansion becomes the new architecture.

    Returns:
        The final architecture: a list of tree-node ids, one per cell edge.
        (Returning it is an addition; the function previously returned None.)
    """
    from collections import deque

    Arch = [0, 0, 0, 0, 0, 0]
    # Fall back to generating the masks here if Build_masks() was never
    # called, instead of silently searching nothing.
    candidate_masks = masks if masks else [To_mask(i) for i in range(2 ** len(Arch))]
    queue = deque([0])  # FIFO of tree nodes to visit
    while queue:
        u = queue.popleft()
        print(u)
        queue.extend(e[u])
        if not e[u]:
            continue
        # If no edge can be split any more, every candidate equals Arch:
        # skip the (expensive) baseline Score call entirely.
        if all(len(e[node]) < 2 for node in Arch):
            continue
        Cur_arch = Build_cand_arch(Arch, To_mask(0))
        score = Score(Cur_arch)
        # Many masks map to the same candidate when some edges cannot be
        # split; evaluate each distinct candidate only once.
        seen = {tuple(Cur_arch)}
        for mask in candidate_masks:
            Cand_arch = Build_cand_arch(Arch, mask)
            key = tuple(Cand_arch)
            if key in seen:
                continue
            seen.add(key)
            print(Cand_arch)
            start_time = time.time()
            if not Check_connected(Cand_arch):
                continue
            cand_score = Score(Cand_arch)
            if score < cand_score:
                Cur_arch = Cand_arch
                score = cand_score
            # Log each scored candidate with its evaluation time.
            print(Cand_arch, time.time() - start_time)
        Arch = Cur_arch
        print(Arch)
    return Arch
if __name__ == "__main__":
    # Run the (expensive) search only when executed as a script, so that
    # importing this module for its helpers does not start training/scoring.
    Build_masks()
    BFS_T_o()