-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathrun_me.lua
38 lines (30 loc) · 765 Bytes
/
run_me.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
require 'torch'
require 'cunn'
require "data.lua"
dofile "etc.lua"
-- Training inputs and targets; assigned by load_data() only in "train" mode.
local trainData, trainLabel

-- Configure the default tensor type and seed Lua's RNG from the wall clock
-- before any of the project scripts below are executed.
torch.setdefaulttensortype('torch.FloatTensor')
math.randomseed(os.time())

-- Run the project scripts in this fixed order; each executes in the global
-- environment, so whatever globals it defines become visible to the code
-- that follows.
for _, script in ipairs({ "data.lua", "model.lua", "train.lua", "test.lua" }) do
  dofile(script)
end
-- Training run: load the data once, then for every epoch run one training
-- pass followed by one evaluation pass, appending one line per epoch to
-- result/loss_<testScale>.txt and result/PSNR_<testScale>.txt.
--
-- Globals read here but defined outside the visible code: mode, testScale,
-- epoch, epochNum, tot_error, cnt_error, PSNR_sum, testDataSz, load_data,
-- train, test.
-- NOTE(review): fp_err, fp_PSNR and err are deliberately left global because
-- the sibling scripts may read them; confirm before making them local.
if mode == "train" then
  trainData, trainLabel = load_data()

  -- io.open returns nil plus an error message on failure (for example when
  -- the result/ directory is missing). Assert here so the run aborts with a
  -- clear message before training starts, instead of failing on the first
  -- write after a full epoch has already been computed.
  fp_err = assert(io.open("result/loss_" .. testScale .. ".txt", "a"))
  fp_PSNR = assert(io.open("result/PSNR_" .. testScale .. ".txt", "a"))

  while epoch <= epochNum do
    train(trainData, trainLabel)
    -- The counter is advanced before test() runs, as in the original order.
    epoch = epoch + 1

    -- tot_error / cnt_error: presumably the mean training loss of the epoch
    -- just finished (both are maintained outside this file) -- TODO confirm.
    err = tot_error / cnt_error
    fp_err:write(err, "\n")
    -- Flush after every epoch so an interrupted run keeps its metrics.
    fp_err:flush()

    test()
    -- PSNR_sum / testDataSz: presumably the mean PSNR over the test set
    -- as accumulated by test() -- TODO confirm.
    fp_PSNR:write(PSNR_sum / testDataSz, "\n")
    fp_PSNR:flush()
  end

  fp_err:close()
  fp_PSNR:close()
end
-- Evaluation-only run: when the global `mode` is "test", call test() once
-- without any training (test is defined outside the visible code).
if mode == "test" then test() end