-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
Copy pathconvert_gpu_cpu_checkpoint.lua
74 lines (63 loc) · 2.13 KB
/
convert_gpu_cpu_checkpoint.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
--[[
A quick patch for converting GPU checkpoints to
CPU checkpoints until I implement a more long-term
solution. Takes the path to a model checkpoint and
writes a converted copy next to it: the output path
is the input path with "_cpu.t7" appended
(e.g. model.t7 -> model.t7_cpu.t7).
]]--
require 'torch'
require 'nn'
require 'nngraph'
require 'lfs'
require 'util.OneHot'
require 'util.misc'
-- Command-line interface: one positional argument (the GPU checkpoint path)
-- plus options selecting which GPU backend to load before deserializing it.
local cmd = torch.CmdLine()
cmd:text()
cmd:text('Convert a GPU character-level language model checkpoint to a CPU checkpoint')
cmd:text()
cmd:text('Options')
cmd:argument('-model','GPU model checkpoint to convert')
cmd:option('-gpuid',0,'which gpu to use. -1 = use CPU')
cmd:option('-opencl',0,'use OpenCL (instead of CUDA)')
cmd:text()
-- parse input params
-- NOTE(review): opt is kept global as in the original script; confirm no other
-- loaded module reads it before making it local.
opt = cmd:parse(arg)
-- Load a GPU backend (its nn package and its tensor package) and select the
-- device given by -gpuid. Prints which package is missing and exits with a
-- non-zero status if either one cannot be required.
--   nn_name  : nn extension package to require (e.g. 'cunn' or 'clnn')
--   lib_name : tensor library package to require (e.g. 'cutorch' or 'cltorch')
--   backend  : label used in the status message (e.g. 'CUDA' or 'OpenCL')
local function setup_gpu_backend(nn_name, lib_name, backend)
  local ok_nn = pcall(require, nn_name)
  local ok_lib, lib = pcall(require, lib_name)
  if not ok_nn then print('package ' .. nn_name .. ' not found!') end
  if not ok_lib then print('package ' .. lib_name .. ' not found!') end
  if not (ok_nn and ok_lib) then
    print('Error, no GPU available?')
    os.exit(1) -- non-zero so callers/scripts see the failure
  end
  -- require yields `true` when a package only sets a global instead of
  -- returning its table; fall back to that global in that case.
  if type(lib) ~= 'table' then lib = _G[lib_name] end
  print('using ' .. backend .. ' on GPU ' .. opt.gpuid .. '...')
  -- -gpuid is 0-based on the command line; setDevice expects a 1-based index
  lib.setDevice(opt.gpuid + 1)
end

-- check that the GPU packages are installed if the user wants to use the GPU:
-- cunn/cutorch for CUDA (-opencl 0), clnn/cltorch for OpenCL (-opencl 1)
if opt.gpuid >= 0 then
  if opt.opencl == 0 then
    setup_gpu_backend('cunn', 'cutorch', 'CUDA')
  elseif opt.opencl == 1 then
    setup_gpu_backend('clnn', 'cltorch', 'OpenCL')
  end
end
-- Load the checkpoint, convert every network in checkpoint.protos to CPU
-- (double) tensors, and save the whole checkpoint under a new name.
-- NOTE(review): the GPU packages are only loaded when -gpuid >= 0; a GPU
-- checkpoint presumably cannot be deserialized without them.
print('loading ' .. opt.model)
checkpoint = torch.load(opt.model)
assert(type(checkpoint) == 'table' and checkpoint.protos ~= nil,
  'checkpoint has no protos table: ' .. opt.model)
protos = checkpoint.protos
-- convert the networks to be CPU models; the return value of :double() is
-- unused, so this relies on it converting each module itself, which keeps
-- checkpoint.protos (saved below) up to date
for name, proto in pairs(protos) do
  print('converting ' .. name .. ' to CPU')
  proto:double()
end
-- output path is the input path with "_cpu.t7" appended
-- (e.g. model.t7 -> model.t7_cpu.t7)
local savefile = opt.model .. '_cpu.t7'
torch.save(savefile, checkpoint)
print('saved ' .. savefile)