From cfe80abc8421bacd0abafbb89612bec491004f05 Mon Sep 17 00:00:00 2001 From: Hiroaki Nakamura Date: Sat, 3 Aug 2019 07:28:05 +0000 Subject: [PATCH] Fix race condition in resty.limit.req using ngx.shared.DICT.set_when --- lib/resty/limit/req.lua | 103 ++++++++++++++++++++++------------------ 1 file changed, 58 insertions(+), 45 deletions(-) diff --git a/lib/resty/limit/req.lua b/lib/resty/limit/req.lua index 8b804fc..0e65d12 100644 --- a/lib/resty/limit/req.lua +++ b/lib/resty/limit/req.lua @@ -67,9 +67,6 @@ end -- sees an new incoming event -- the "commit" argument controls whether should we record the event in shm. --- FIXME we have a (small) race-condition window between dict:get() and --- dict:set() across multiple nginx worker processes. The size of the --- window is proportional to the number of workers. function _M.incoming(self, key, commit) local dict = self.dict local rate = self.rate @@ -79,36 +76,45 @@ function _M.incoming(self, key, commit) -- it's important to anchor the string value for the read-only pointer -- cdata: - local v = dict:get(key) - if v then - if type(v) ~= "string" or #v ~= rec_size then - return nil, "shdict abused by other users" - end - local rec = ffi_cast(const_rec_ptr_type, v) - local elapsed = now - tonumber(rec.last) - - -- print("elapsed: ", elapsed, "ms") - - -- we do not handle changing rate values specifically. the excess value - -- can get automatically adjusted by the following formula with new rate - -- values rather quickly anyway. - excess = max(tonumber(rec.excess) - rate * abs(elapsed) / 1000 + 1000, - 0) - - -- print("excess: ", excess) - - if excess > self.burst then - return nil, "rejected" - end - - else - excess = 0 - end - - if commit then - rec_cdata.excess = excess - rec_cdata.last = now - dict:set(key, ffi_str(rec_cdata, rec_size)) + while true do + local v = dict:get(key) + if v then + if type(v) ~= "string" or #v ~= rec_size then + return nil, "shdict abused by other users" + end + local rec = ffi_cast(const_rec_ptr_type, v) + local elapsed = now - tonumber(rec.last) + + -- print("elapsed: ", elapsed, "ms") + + -- we do not handle changing rate values specifically. the excess value + -- can get automatically adjusted by the following formula with new rate + -- values rather quickly anyway. + excess = max(tonumber(rec.excess) - rate * abs(elapsed) / 1000 + 1000, + 0) + + -- print("excess: ", excess) + + if excess > self.burst then + return nil, "rejected" + end + + else + excess = 0 + end + + if commit then + rec_cdata.excess = excess + rec_cdata.last = now + local succ, err, forcible = dict:set_when(key, v, ffi_str(rec_cdata, rec_size)) + if succ then + break + elseif err ~= "already modified" then + return nil, err + end + else + break + end end -- return the delay in seconds, as well as excess @@ -120,22 +126,29 @@ function _M.uncommit(self, key) assert(key) local dict = self.dict - local v = dict:get(key) - if not v then - return nil, "not found" - end + while true do + local v = dict:get(key) + if not v then + return nil, "not found" + end - if type(v) ~= "string" or #v ~= rec_size then - return nil, "shdict abused by other users" - end + if type(v) ~= "string" or #v ~= rec_size then + return nil, "shdict abused by other users" + end - local rec = ffi_cast(const_rec_ptr_type, v) + local rec = ffi_cast(const_rec_ptr_type, v) - local excess = max(tonumber(rec.excess) - 1000, 0) + local excess = max(tonumber(rec.excess) - 1000, 0) - rec_cdata.excess = excess - rec_cdata.last = rec.last - dict:set(key, ffi_str(rec_cdata, rec_size)) + rec_cdata.excess = excess + rec_cdata.last = rec.last + local succ, err, forcible = dict:set_when(key, v, ffi_str(rec_cdata, rec_size)) + if succ then + break + elseif err ~= "already modified" then + return nil, err + end + end return true end