-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathDemo_Test.m
124 lines (100 loc) · 4.37 KB
/
Demo_Test.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% @inproceedings{Canh2018_MSCSNet,
% title={Multi-Scale Deep Compressive Sensing Network},
% author={Thuong, Nguyen Canh and Byeungwoo, Jeon},
% booktitle={IEEE International Conference on Visual Communication and Image Processing},
% year={2018}
% }
% by Thuong Nguyen Canh (9/2018)
% https://github.com/AtenaKid
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% You need to install MatConvNet (and have a GPU available, since the network and inputs are moved to the GPU) in order to run this code
%%% Run configuration
% Note: this disables ALL warnings globally for the session and is not
% restored at the end of the script (e.g. addpath on a missing folder
% will fail silently).
warning('off','all')
addpath('D:\matconvnet-1.0-beta25\matlab\mex'); % Link to your MatConvNet mex folder (edit for your install)
% addpath('matconvnet-1.0-beta25\matlab');
% Project helper folder, built with fullfile so the separator is correct on
% every platform. NOTE(review): presumably provides modcrop / Cal_PSNRSSIM /
% write_txt used below -- confirm.
addpath(fullfile('.', 'utilities'));
folderTest = 'Classic13_512';   % test-set sub-folder of testsets/
networkTest = {'CSNet' 'W-DCS1' 'W-DCS2' 'W-DCS3' 'SS-DCS1' 'SS-DCS2' 'SS-DCS3' ...
    'P-DCS1' 'P-DCS2' 'P-DCS3' 'DoC-DCS1' 'DoC-DCS2' 'DoC-DCS3'}; % 13 networks
showResult = 0;      % 1: print per-image PSNR/SSIM to the console
writeRecon = 1;      % 1: write each reconstruction as a PNG under Results/
featureSize = 64;    % not used by this script
blkSize = 32;        % images must tile exactly into blkSize x blkSize blocks
isLearnMtx = [1, 0]; % not used by this script
network = networkTest{1};  % network to evaluate (index into networkTest)
%%% Evaluate the selected network at each sampling rate.
% For every rate this loads models/<network>/<network>_r<rate>.mat,
% reconstructs every image in testsets/<folderTest> whose (luma) size is a
% multiple of blkSize, reports PSNR/SSIM via Cal_PSNRSSIM, optionally
% writes each reconstruction to Results/2Image_<network>/, and writes a
% per-rate summary with write_txt to Results/1Text_<network>/.
% After each rate, PSNRs_CSNet, SSIMs_CSNet, time and allName hold only the
% images that were actually evaluated (skipped images are excluded so they
% do not drag the averages down and names stay aligned with scores).

%%% list the test images once -- the set does not depend on the rate
ext = {'*.jpg', '*.png', '*.bmp', '*.pgm', '*.tif'};
filePaths = [];
for k = 1:numel(ext)
    filePaths = cat(1, filePaths, dir(fullfile('testsets', folderTest, ext{k})));
end
nImages = numel(filePaths);

for samplingRate = 0.1:0.1:0.3
    %%% load the model for this sampling rate
    modelName = [network '_r' num2str(samplingRate)];
    data = load(fullfile('models', network, [modelName, '.mat']));
    net = dagnn.DagNN.loadobj(data.net);
    if strcmp(network, 'CSNet')
        % CSNet checkpoints name their variables x0 / x12; rename them so
        % the code below can address 'input' / 'prediction' uniformly.
        net.renameVar('x0', 'input');
        net.renameVar('x12', 'prediction');
    else
        % Other checkpoints: drop the last layer.
        % NOTE(review): presumably a training-only loss layer -- confirm.
        net.removeLayer(net.layers(end).name);
    end
    net.mode = 'test';
    net.move('gpu');
    gpu = gpuDevice;  % used to wait for asynchronous GPU work before toc

    % Per-rate results, re-initialised for every rate.
    PSNRs_CSNet = zeros(1, nImages);
    SSIMs_CSNet = zeros(1, nImages);
    time = zeros(1, nImages);
    allName = cell(1, nImages);
    evaluated = false(1, nImages);  % true for images actually reconstructed

    if writeRecon
        imgFolder = fullfile('Results', ['2Image_' network]);
        if ~exist(imgFolder, 'dir'), mkdir(imgFolder); end
    end

    for i = 1:nImages
        %%% read image; colour images are cropped to whole blocks and
        %%% reduced to their first YCbCr channel (luma)
        img = imread(fullfile('testsets', folderTest, filePaths(i).name));
        [~, nameCur] = fileparts(filePaths(i).name);
        if size(img, 3) == 3
            img = modcrop(img, blkSize);
            img = rgb2ycbcr(img);
            img = img(:, :, 1);
        end
        label = im2single(img);
        % Images that do not tile exactly into blkSize blocks are skipped
        % and excluded from all statistics below.
        if mod(size(label, 1), blkSize) ~= 0 || mod(size(label, 2), blkSize) ~= 0
            continue
        end

        %%% reconstruct; GPU calls are asynchronous, so wait before toc
        tic
        net.eval({'input', gpuArray(label)});
        wait(gpu);
        time(i) = toc;
        output = gather(squeeze(net.vars(net.getVarIndex('prediction')).value));

        %%% calculate PSNR and SSIM
        [PSNRCur_CSNet, SSIMCur_CSNet] = Cal_PSNRSSIM(im2uint8(label), im2uint8(output), 0, 0);
        if showResult
            disp([' ' filePaths(i).name, ' ', num2str(PSNRCur_CSNet, '%2.2f'), 'dB', ' ', num2str(SSIMCur_CSNet, '%2.3f')])
        end
        PSNRs_CSNet(i) = PSNRCur_CSNet;
        SSIMs_CSNet(i) = SSIMCur_CSNet;
        allName{i} = nameCur;
        evaluated(i) = true;

        %%% save the reconstruction for the current image
        if writeRecon
            fileName = fullfile(imgFolder, [folderTest '_' nameCur '_subrate' num2str(samplingRate) '.png']);
            imwrite(im2uint8(output), fileName);
        end
    end

    % Keep only evaluated images so names, scores and times line up and
    % skipped images do not contribute placeholder zeros to the means.
    allName = allName(evaluated);
    PSNRs_CSNet = PSNRs_CSNet(evaluated);
    SSIMs_CSNet = SSIMs_CSNet(evaluated);
    time = time(evaluated);

    %%% save the summary for the current sampling rate
    txtFolder = fullfile('Results', ['1Text_' network]);
    if ~exist(txtFolder, 'dir'), mkdir(txtFolder); end
    fileName = fullfile(txtFolder, [folderTest '_subrate' num2str(samplingRate) '.txt']);
    write_txt(fileName, allName, samplingRate, PSNRs_CSNet, SSIMs_CSNet, time);
    disp(['Average, subrate ' num2str(samplingRate) ': ' num2str(mean(PSNRs_CSNet), ...
        '%2.3f') 'dB, SSIM: ', num2str(mean(SSIMs_CSNet), '%2.4f'), ', time: ', num2str(mean(time), '%2.4f')]);
end