Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race condition in resty.limit.req using ngx.shared.DICT.set_when #49

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 58 additions & 45 deletions lib/resty/limit/req.lua
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ end

-- sees an new incoming event
-- the "commit" argument controls whether should we record the event in shm.
-- FIXME we have a (small) race-condition window between dict:get() and
-- dict:set() across multiple nginx worker processes. The size of the
-- window is proportional to the number of workers.
function _M.incoming(self, key, commit)
local dict = self.dict
local rate = self.rate
Expand All @@ -79,36 +76,45 @@ function _M.incoming(self, key, commit)

-- it's important to anchor the string value for the read-only pointer
-- cdata:
local v = dict:get(key)
if v then
if type(v) ~= "string" or #v ~= rec_size then
return nil, "shdict abused by other users"
end
local rec = ffi_cast(const_rec_ptr_type, v)
local elapsed = now - tonumber(rec.last)

-- print("elapsed: ", elapsed, "ms")

-- we do not handle changing rate values specifically. the excess value
-- can get automatically adjusted by the following formula with new rate
-- values rather quickly anyway.
excess = max(tonumber(rec.excess) - rate * abs(elapsed) / 1000 + 1000,
0)

-- print("excess: ", excess)

if excess > self.burst then
return nil, "rejected"
end

else
excess = 0
end

if commit then
rec_cdata.excess = excess
rec_cdata.last = now
dict:set(key, ffi_str(rec_cdata, rec_size))
while true do
local v = dict:get(key)
if v then
if type(v) ~= "string" or #v ~= rec_size then
return nil, "shdict abused by other users"
end
local rec = ffi_cast(const_rec_ptr_type, v)
local elapsed = now - tonumber(rec.last)

-- print("elapsed: ", elapsed, "ms")

-- we do not handle changing rate values specifically. the excess value
-- can get automatically adjusted by the following formula with new rate
-- values rather quickly anyway.
excess = max(tonumber(rec.excess) - rate * abs(elapsed) / 1000 + 1000,
0)

-- print("excess: ", excess)

if excess > self.burst then
return nil, "rejected"
end

else
excess = 0
end

if commit then
rec_cdata.excess = excess
rec_cdata.last = now
local succ, err, forcible = dict:set_when(key, v, ffi_str(rec_cdata, rec_size))
if succ then
break
elseif err ~= "already modified" then
return nil, err
end
else
break
end
end

-- return the delay in seconds, as well as excess
Expand All @@ -120,22 +126,29 @@ function _M.uncommit(self, key)
assert(key)
local dict = self.dict

local v = dict:get(key)
if not v then
return nil, "not found"
end
while true do
local v = dict:get(key)
if not v then
return nil, "not found"
end

if type(v) ~= "string" or #v ~= rec_size then
return nil, "shdict abused by other users"
end
if type(v) ~= "string" or #v ~= rec_size then
return nil, "shdict abused by other users"
end

local rec = ffi_cast(const_rec_ptr_type, v)
local rec = ffi_cast(const_rec_ptr_type, v)

local excess = max(tonumber(rec.excess) - 1000, 0)
local excess = max(tonumber(rec.excess) - 1000, 0)

rec_cdata.excess = excess
rec_cdata.last = rec.last
dict:set(key, ffi_str(rec_cdata, rec_size))
rec_cdata.excess = excess
rec_cdata.last = rec.last
local succ, err, forcible = dict:set_when(key, v, ffi_str(rec_cdata, rec_size))
if succ then
break
elseif err ~= "already modified" then
return nil, err
end
end
return true
end

Expand Down