Skip to content

Commit

Permalink
Refactor error handling and bring code coverage to 100% (#6)
Browse files Browse the repository at this point in the history
* Refactor error handling and bring code coverage to 100%

Signed-off-by: Matteo Collina <[email protected]>

* fixup

Signed-off-by: Matteo Collina <[email protected]>

---------

Signed-off-by: Matteo Collina <[email protected]>
  • Loading branch information
mcollina authored Feb 2, 2024
1 parent f8126a4 commit ecc0e16
Show file tree
Hide file tree
Showing 7 changed files with 157 additions and 43 deletions.
4 changes: 0 additions & 4 deletions index.js

This file was deleted.

File renamed without changes.
38 changes: 20 additions & 18 deletions oauth-interceptor.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
'use strict'

const { createDecoder } = require('fast-jwt')
const { refreshAccessToken } = require('./utils')
const { refreshAccessToken } = require('./lib/utils')
const { RetryHandler } = require('undici')

const decode = createDecoder()
Expand All @@ -19,26 +19,11 @@ function getTokenState (token) {

const { exp } = decode(token)

if (!exp) return TOKEN_STATE.EXPIRED
if (exp <= (Date.now() + EXP_DIFF_MS) / 1000) return TOKEN_STATE.EXPIRED
if (exp <= (Date.now() + NEAR_EXP_DIFF_MS) / 1000) return TOKEN_STATE.NEAR_EXPIRATION
return TOKEN_STATE.VALID
}

let _requestingRefresh
function callRefreshToken (refreshEndpoint, refreshToken, clientId) {
if (_requestingRefresh) return _requestingRefresh

_requestingRefresh = refreshAccessToken({ refreshEndpoint, refreshToken, clientId })
_requestingRefresh.catch(() => _requestingRefresh = null)
_requestingRefresh.then((result) => {
_requestingRefresh = null
return result
})

return _requestingRefresh
}

function createOAuthInterceptor (options) {
let { accessToken } = { ...options }
const {
Expand All @@ -64,6 +49,16 @@ function createOAuthInterceptor (options) {
const refreshHost = iss
const client = clientId || sub

let _requestingRefresh
function callRefreshToken (refreshEndpoint, refreshToken, clientId) {
if (_requestingRefresh) return _requestingRefresh

_requestingRefresh = refreshAccessToken({ refreshEndpoint, refreshToken, clientId })
.finally(() => _requestingRefresh = null)

return _requestingRefresh
}

return dispatch => {
return function Intercept (opts, handler) {
if (!opts.oauthRetry && (origins.length > 0 && !origins.includes(opts.origin))) {
Expand All @@ -73,6 +68,9 @@ function createOAuthInterceptor (options) {

if (opts.oauthRetry) {
return callRefreshToken(refreshHost, refreshToken, client)
.catch(err => {
handler.onError(err)
})
.then(newAccessToken => {
accessToken = newAccessToken

Expand Down Expand Up @@ -111,18 +109,22 @@ function createOAuthInterceptor (options) {
...opts.headers,
authorization: `Bearer ${accessToken}`
}
dispatcher.emit('oauth:token-refreshed', newAccessToken)
return dispatch(opts, retryHandler)
}

switch (getTokenState(accessToken)) {
case TOKEN_STATE.EXPIRED:
return callRefreshToken(refreshHost, refreshToken, client)
.then(saveTokenAndRetry)
.catch(err => handler.onError(err))
.catch(err => {
handler.onError(err)
})
case TOKEN_STATE.NEAR_EXPIRATION:
callRefreshToken(refreshHost, refreshToken, client)
.then(newAccessToken => {
accessToken = newAccessToken
dispatcher.emit('oauth:token-refreshed', newAccessToken)
})
.catch(/* do nothing */)
default:
Expand All @@ -132,4 +134,4 @@ function createOAuthInterceptor (options) {
}
}

module.exports = { createOAuthInterceptor }
module.exports.createOAuthInterceptor = createOAuthInterceptor
5 changes: 3 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
"name": "undici-oauth-interceptor",
"version": "0.3.0",
"description": "Automatically manage OAuth 2.0 access tokens for Undici requests",
"main": "index.js",
"main": "oauth-interceptor.js",
"scripts": {
"lint": "standard",
"test": "node --test tests/*.test.js"
"test": "borp --coverage"
},
"author": "Platformatic",
"license": "MIT",
Expand All @@ -14,6 +14,7 @@
"undici": "^6.0.0"
},
"devDependencies": {
"borp": "^0.9.1",
"fastify": "^4.24.3",
"standard": "^17.1.0"
}
Expand Down
90 changes: 90 additions & 0 deletions tests/errors.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
'use strict'

const test = require('node:test')
const assert = require('node:assert')
const http = require('node:http')
const { once, EventEmitter } = require('node:events')
const { request, Agent } = require('undici')
const { createDecoder } = require('fast-jwt')
const { createOAuthInterceptor } = require('../')
const { createToken } = require('./helper')

test('error when refreshing', async (t) => {
const accessToken = createToken({ name: 'access' }, { expiresIn: '1ms' })

const mainServer = http.createServer((req, res) => {
assert.fail('should not be called')
})
mainServer.listen(0)

const tokenServer = http.createServer((req, res) => {
assert.strictEqual(req.method, 'POST')
assert.strictEqual(req.url, '/token')

res.writeHead(400)
res.end(JSON.stringify({ message: 'kaboom' }))
})
tokenServer.listen(0)

t.after(() => {
mainServer.close()
tokenServer.close()
})

const refreshToken = createToken(
{ name: 'refresh' },
{ expiresIn: '1d', iss: `http://localhost:${tokenServer.address().port}`, sub: 'client-id' }
)

const dispatcher = new Agent({
interceptors: {
Pool: [createOAuthInterceptor({
accessToken,
refreshToken,
retryOnStatusCodes: [401]
})]
}
})

await assert.rejects(request(`http://localhost:${mainServer.address().port}`, { dispatcher }))
})

test('after service rejects the token, token service reject token, error request', async (t) => {
const accessToken = createToken({ name: 'access' }, { expiresIn: '1d' })
const mainServer = http.createServer((req, res) => {
res.writeHead(401)
return res.end()
})
mainServer.listen(0)

const tokenServer = http.createServer((req, res) => {
assert.strictEqual(req.method, 'POST')
assert.strictEqual(req.url, '/token')

res.writeHead(403)
res.end(JSON.stringify({ message: 'kaboom' }))
})
tokenServer.listen(0)

t.after(() => {
mainServer.close()
tokenServer.close()
})

const refreshToken = createToken(
{ name: 'refresh' },
{ expiresIn: '1d', iss: `http://localhost:${tokenServer.address().port}`, sub: 'client-id' }
)

const dispatcher = new Agent({
interceptors: {
Pool: [createOAuthInterceptor({
accessToken,
refreshToken,
retryOnStatusCodes: [401]
})]
}
})

assert.rejects(request(`http://localhost:${mainServer.address().port}`, { dispatcher }))
})
61 changes: 43 additions & 18 deletions tests/interceptor.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ const http = require('node:http')
const { once, EventEmitter } = require('node:events')
const { request, Agent } = require('undici')
const { createDecoder } = require('fast-jwt')
const { createOAuthInterceptor } = require('../oauth-interceptor')
const { createOAuthInterceptor } = require('../')
const { createToken } = require('./helper')

test('attach provided access token to the request', async (t) => {
Expand Down Expand Up @@ -125,12 +125,12 @@ test('refresh access token if expired', async (t) => {
})

test('refresh token within refresh window', async (t) => {
const ee = new EventEmitter()
let accessToken = createToken({ name: 'access' }, { expiresIn: '29s' })
const oldAccessToken = createToken({ name: 'access' }, { expiresIn: '29s' })
const newAccessToken = createToken({ name: 'access' }, { expiresIn: '1d' })

const mainServer = http.createServer((req, res) => {
assert.ok(req.headers.authorization.length > 'Bearer '.length)
assert.notStrictEqual(req.headers.authorization, `Bearer ${accessToken}`)
assert.strictEqual(req.headers.authorization, `Bearer ${oldAccessToken}`)
res.writeHead(200)
res.end()
})
Expand All @@ -147,10 +147,8 @@ test('refresh token within refresh window', async (t) => {
assert.strictEqual(grant_type, 'refresh_token')
})

accessToken = createToken({ name: 'access' }, { expiresIn: '1d' })
res.writeHead(200)
res.end(JSON.stringify({ access_token: accessToken }))
ee.emit('token-refreshed')
res.end(JSON.stringify({ access_token: newAccessToken }))
})
tokenServer.listen(0)

Expand All @@ -167,14 +165,14 @@ test('refresh token within refresh window', async (t) => {
const dispatcher = new Agent({
interceptors: {
Pool: [createOAuthInterceptor({
accessToken,
accessToken: oldAccessToken,
refreshToken,
retryOnStatusCodes: [401]
})]
}
})

const tokenRefreshed = once(ee, 'token-refreshed')
const tokenRefreshed = once(dispatcher, 'oauth:token-refreshed')

const { statusCode } = await request(`http://localhost:${mainServer.address().port}`, { dispatcher })
assert.strictEqual(statusCode, 200)
Expand Down Expand Up @@ -516,18 +514,18 @@ test('error handling on creation', async (t) => {
* make second request which uses new token
*/
test('optimistic refresh', async (t) => {
const ee = new EventEmitter()
let accessToken = createToken({ name: 'access' }, { expiresIn: '29s' })
const oldAccessToken = createToken({ name: 'access' }, { expiresIn: '29s' })
const newAccessToken = createToken({ name: 'access' }, { expiresIn: '1d' })

let requestCount = 0
const mainServer = http.createServer((req, res) => {
requestCount += 1
if (requestCount === 1) {
assert.ok(req.headers.authorization.length > 'Bearer '.length)
assert.notStrictEqual(req.headers.authorization, `Bearer ${accessToken}`, 'token should not be the same on first request')
assert.strictEqual(req.headers.authorization, `Bearer ${oldAccessToken}`, 'token should be the old one in first request')
} else {
assert.ok(req.headers.authorization.length > 'Bearer '.length)
assert.strictEqual(req.headers.authorization, `Bearer ${accessToken}`, 'token should be the same on second request')
assert.strictEqual(req.headers.authorization, `Bearer ${newAccessToken}`, 'token should be the new one in second request')
}

res.writeHead(200)
Expand All @@ -547,10 +545,8 @@ test('optimistic refresh', async (t) => {
assert.ok(refresh_token)
})

accessToken = createToken({ name: 'access' }, { expiresIn: '1d' })
res.writeHead(200)
res.end(JSON.stringify({ access_token: accessToken }))
ee.emit('token-refreshed')
res.end(JSON.stringify({ access_token: newAccessToken }))
})
tokenServer.listen(0)

Expand All @@ -567,14 +563,14 @@ test('optimistic refresh', async (t) => {
const dispatcher = new Agent({
interceptors: {
Pool: [createOAuthInterceptor({
accessToken,
accessToken: oldAccessToken,
refreshToken,
retryOnStatusCodes: [401]
})]
}
})

const tokenRefreshed = once(ee, 'token-refreshed')
const tokenRefreshed = once(dispatcher, 'oauth:token-refreshed')

{
const { statusCode } = await request(`http://localhost:${mainServer.address().port}`, { dispatcher })
Expand All @@ -590,3 +586,32 @@ test('optimistic refresh', async (t) => {

assert.strictEqual(requestCount, 2)
})

test('do not intercept if not in the origins', async (t) => {
const accessToken = createToken({ name: 'access' }, { expiresIn: '1d' })
const refreshToken = createToken(
{ name: 'refresh' },
{ expiresIn: '1d', iss: 'doesntmatter.com', sub: 'client-id' }
)

const server = http.createServer((req, res) => {
assert.strictEqual(req.headers.authorization, undefined)
res.writeHead(200)
res.end()
})
server.listen(0)

t.after(() => server.close())

const origin = `http://localhost:${server.address().port}`

const dispatcher = new Agent({
interceptors: {
Pool: [createOAuthInterceptor({ accessToken, refreshToken, origins: [origin] })]
}
})

// we use a different origin
const { statusCode } = await request(`http://127.0.0.1:${server.address().port}`, { dispatcher })
assert.strictEqual(statusCode, 200)
})
2 changes: 1 addition & 1 deletion tests/utils.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
const { test } = require('node:test')
const assert = require('node:assert')
const { setGlobalDispatcher, MockAgent } = require('undici')
const { refreshAccessToken } = require('../utils')
const { refreshAccessToken } = require('../lib/utils')

const mockAgent = new MockAgent()
setGlobalDispatcher(mockAgent)
Expand Down

0 comments on commit ecc0e16

Please sign in to comment.