forked from Moodstocks/gtsrb
-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathmatching.lua
87 lines (64 loc) · 1.88 KB
/
matching.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
--- Run the network on one input and return a module's output as a row.
-- Performs a forward pass through `cnn`, then takes the output of the
-- module at position `index_conv` in `cnn.modules` and reshapes it into a
-- single 1 x N row, where N is the output's total element count.
-- @param cnn network exposing `forward` and a `modules` list
-- @param index_conv index into `cnn.modules` of the module whose output is used
-- @param input value passed to `cnn:forward`
-- @return 1 x N tensor holding that module's output
local get_features = function(cnn, index_conv, input)
  -- The forward pass populates each module's `output` field.
  cnn:forward(input)
  local output = cnn.modules[index_conv].output
  -- nElement() covers outputs of any rank, not just 3-D ones.
  -- NOTE(review): torch7's reshape copies data rather than returning a view,
  -- which matters because the module's output buffer is presumably reused
  -- by the next forward pass.
  return output:reshape(1, output:nElement())
end
--- Check whether the first `size` integer slots of a table are all set.
-- @param t table to inspect
-- @param size number of leading slots (keys 1..size) that must be truthy
-- @return true if t[1] .. t[size] are all non-nil and non-false, else false
local all_filled = function(t, size)
  -- The parameter is named `t` (not `table`) so the standard `table`
  -- library is not shadowed inside this function.
  for i = 1, size do
    if not t[i] then
      return false
    end
  end
  return true
end
--- Pick one reference feature vector per class from a dataset.
-- Walks the dataset in order and keeps the features of the first sample
-- seen for each label. It stops once labels 1..nbr_classes all have a
-- reference.
-- @param cnn network passed to get_features
-- @param index_conv index of the module whose output is used as features
-- @param dataset indexable samples; sample[1] is the network input and
--   sample[2][1] its label
-- @param nbr_classes optional number of classes to collect (default 43,
--   presumably the GTSRB class count)
-- @return table mapping label -> feature tensor
-- @raise error if the dataset runs out before every class has a reference
local select_references = function(cnn, index_conv, dataset, nbr_classes)
  nbr_classes = nbr_classes or 43
  local references = {}
  local index = 1
  while not all_filled(references, nbr_classes) do
    -- Declared local: a bare assignment here would leak a global `sample`.
    local sample = dataset[index]
    if sample == nil then
      error(string.format(
        'select_references: dataset ran out after %d samples before every class in 1..%d had a reference',
        index - 1, nbr_classes), 2)
    end
    local label = sample[2][1]
    if not references[label] then
      references[label] = get_features(cnn, index_conv, sample[1])
    end
    index = index + 1
  end
  return references
end
--- Euclidean (L2) distance between two feature tensors.
-- @param tensor1 first tensor
-- @param tensor2 second tensor; assumed to have as many elements as tensor1
-- @return number, sqrt of the sum of squared element-wise differences
local distance_function = function(tensor1, tensor2)
  local diff = tensor1 - tensor2
  -- Tensor:dot treats both operands as flat vectors, so no transpose is
  -- needed. Dropping transpose(1,2) also lets this work on 1-D tensors,
  -- where that call would fail. dot returns a plain number, so math.sqrt
  -- is sufficient.
  return math.sqrt(diff:dot(diff))
end
--- Classify one input by its nearest reference in feature space.
-- Extracts the sample's features, compares them with every reference using
-- distance_function, and returns the label of the closest one.
-- @param cnn network passed to get_features
-- @param index_conv index of the module whose output is used as features
-- @param references table mapping label -> reference feature tensor
-- @param sample network input to classify
-- @return label of the closest reference (smallest label on ties), or 0 if
--   no reference is strictly closer than infinity
local classify_intput = function(cnn, index_conv, references, sample)
  local sample_features = get_features(cnn, index_conv, sample)
  local best_choice = 0
  local best_distance = math.huge
  -- pairs rather than ipairs: ipairs stops at the first missing label, so a
  -- gap in `references` would silently hide every class after it.
  -- Because pairs order is unspecified, ties go to the smaller label. This
  -- keeps the result deterministic and equal to an ascending scan.
  for label, ref_features in pairs(references) do
    local dist = distance_function(ref_features, sample_features)
    if dist < best_distance
        or (dist == best_distance and label < best_choice) then
      best_distance = dist
      best_choice = label
    end
  end
  return best_choice
end
--- Measure, print and return the nearest-reference error rate on a dataset.
-- References are picked from the same dataset via select_references. Every
-- sample is then classified against them with classify_intput.
-- @param cnn network used to extract features
-- @param index_conv index of the module whose output is used as features
-- @param dataset sequence of samples; sample[1] is the input and sample[2][1]
--   its label
-- @return fraction of samples whose predicted label differs from sample[2][1]
local test_matching = function(cnn, index_conv, dataset)
  local references = select_references(cnn, index_conv, dataset)
  local nbr_elements = 0
  local nbr_false = 0
  for _, sample in ipairs(dataset) do
    local prediction = classify_intput(cnn, index_conv, references, sample[1])
    if prediction ~= sample[2][1] then
      nbr_false = nbr_false + 1
    end
    nbr_elements = nbr_elements + 1
  end
  local error_rate = nbr_false / nbr_elements
  print('Error rate using matching on the given set is: ' .. error_rate .. '.')
  -- Returned as well as printed so callers can use the value programmatically.
  return error_rate
end
-- Module interface: only test_matching is exported; every other function in
-- this file stays a private local.
local matching = {}
matching.test_matching = test_matching
return matching