-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathoptim-adam-single.lua
38 lines (32 loc) · 1.14 KB
/
optim-adam-single.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
-- Adam optimizer
-- Variant intended for use by a single worker only.
-- Author: Minwei Feng ([email protected])
require 'optim'
--- Single-worker Adam step, optionally followed by a parameter send.
-- Applies one Adam update (Kingma & Ba) to `w` in place, with the bias
-- correction folded into the step size. If `config.pclient` is set, it then
-- calls `pclient:async_send_param()` and `pclient:wait()`.
-- @param opfunc  function(w) returning fx, dfdx (loss and gradient at w)
-- @param w       parameter tensor; modified in place
-- @param config  table with optional fields lr, beta1, beta2, epsilon,
--                pclient; defaults for the first four match optim.adam
-- @param state   table holding the optimizer state (adam_t, adam_m,
--                adam_v, adam_d, pversion); defaults to `config`
-- @return w, {fx}
function optim.adamsingle(opfunc, w, config, state)
   local config = config or {}
   local state = state or config
   local lr = config.lr or 0.001
   local beta1 = config.beta1 or 0.9
   local beta2 = config.beta2 or 0.999
   local epsilon = config.epsilon or 1e-8
   local pc = config.pclient

   state.pversion = state.pversion or 0

   local fx, dfdx = opfunc(w)

   -- Allocate state buffers with dfdx's own tensor type (e.g. CudaTensor),
   -- not the default torch.Tensor type, so later in-place ops never mix types.
   state.adam_t = state.adam_t or 0
   state.adam_m = state.adam_m or dfdx.new(dfdx:size()):zero()
   state.adam_v = state.adam_v or dfdx.new(dfdx:size()):zero()
   state.adam_d = state.adam_d or dfdx.new(dfdx:size()):zero()

   state.adam_t = state.adam_t + 1

   -- First moment: m = beta1 * m + (1 - beta1) * g
   state.adam_m:mul(beta1):add(1 - beta1, dfdx)
   -- Second moment: v = beta2 * v + (1 - beta2) * g^2
   state.adam_v:mul(beta2):addcmul(1 - beta2, dfdx, dfdx)
   -- Denominator: sqrt(v) + epsilon
   state.adam_d:copy(state.adam_v):sqrt():add(epsilon)

   -- Bias corrections, applied through the step size rather than to m and v.
   -- `^` is used instead of math.pow, which is missing from Lua 5.3+.
   local beta1_t = 1 - beta1 ^ state.adam_t
   local beta2_t = 1 - beta2 ^ state.adam_t
   local lr_t = lr * math.sqrt(beta2_t) / beta1_t

   w:addcdiv(-lr_t, state.adam_m, state.adam_d)
   state.pversion = state.pversion + 1

   -- Send the updated parameters only when a client is configured. The
   -- original crashed here when config.pclient was absent.
   -- NOTE(review): assumes async_send_param() starts the transfer and
   -- wait() blocks until it completes — confirm against the client class.
   if pc then
      pc:async_send_param()
      pc:wait()
   end

   return w, {fx}
end