Skip to content

Commit

Permalink
Merge branch 'main' of github.com:platformatic/undici-oauth-dispatch
Browse files Browse the repository at this point in the history
Signed-off-by: Matteo Collina <[email protected]>
  • Loading branch information
mcollina committed Feb 19, 2024
2 parents aff37c3 + 67951b6 commit 80dfec7
Show file tree
Hide file tree
Showing 11 changed files with 284 additions and 78 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ jobs:
matrix:
node-version: [18.x, 20.x]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Use Node.js
uses: actions/setup-node@v2
uses: actions/setup-node@v4
with:
node-version: ${{ matrix.node-version }}

Expand Down
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# undici-oauth-interceptor

[![NPM version](https://img.shields.io/npm/v/undici-oauth-interceptor.svg?style=flat)](https://www.npmjs.com/package/undici-oauth-interceptor)

Manages an access token and automatically sets the `Authorization` header on any
request that is going to a limited set of domains.

Expand Down
4 changes: 0 additions & 4 deletions index.js

This file was deleted.

4 changes: 2 additions & 2 deletions utils.js → lib/utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

const { request } = require('undici')

async function refreshAccessToken (refreshEndpoint, clientId, refreshToken) {
async function refreshAccessToken ({ refreshEndpoint, refreshToken, clientId }) {
const { statusCode, body } = await request(`${refreshEndpoint}/token`, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
token: refreshToken,
refresh_token: refreshToken,
grant_type: 'refresh_token',
client_id: clientId
})
Expand Down
67 changes: 34 additions & 33 deletions oauth-interceptor.js
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
'use strict'

const { createDecoder } = require('fast-jwt')
const { refreshAccessToken } = require('./utils')
const { RetryHandler } = require('undici')
const { refreshAccessToken } = require('./lib/utils')
const { RetryHandler, getGlobalDispatcher } = require('undici')

const decode = createDecoder()
const EXP_DIFF_MS = 10 * 1000
Expand All @@ -19,60 +19,56 @@ function getTokenState (token) {

const { exp } = decode(token)

if (!exp) return TOKEN_STATE.EXPIRED
if (exp <= (Date.now() + EXP_DIFF_MS) / 1000) return TOKEN_STATE.EXPIRED
if (exp <= (Date.now() + NEAR_EXP_DIFF_MS) / 1000) return TOKEN_STATE.NEAR_EXPIRATION
return TOKEN_STATE.VALID
}

let _requestingRefresh
function callRefreshToken (refreshHost, refreshToken, clientId) {
if (_requestingRefresh) return _requestingRefresh

_requestingRefresh = refreshAccessToken(refreshHost, refreshToken, clientId)
_requestingRefresh.catch(() => _requestingRefresh = null)
_requestingRefresh.then((result) => {
_requestingRefresh = null
return result
})

return _requestingRefresh
}

function createOAuthInterceptor (options) {
let { accessToken } = { ...options }
const {
refreshToken,
const { refreshToken, clientId } = options
let {
accessToken ,
retryOnStatusCodes,
origins,
clientId
} = {
retryOnStatusCodes: [401],
origins: [],
refreshToken: '',
...options
}
origins
} = options

retryOnStatusCodes = retryOnStatusCodes || [401]
origins = origins || []

if (!refreshToken) {
throw new Error('refreshToken is required')
}

const { iss, sub } = decode(refreshToken)
const decoded = decode(refreshToken)
const { iss, sub } = decoded
if (!iss) throw new Error('refreshToken is invalid: iss is required')
if (!sub && !clientId) throw new Error('No clientId provided')

const refreshHost = iss
const client = clientId || sub

let _requestingRefresh
function callRefreshToken (refreshEndpoint, refreshToken, clientId) {
if (_requestingRefresh) return _requestingRefresh

_requestingRefresh = refreshAccessToken({ refreshEndpoint, refreshToken, clientId })
.finally(() => _requestingRefresh = null)

return _requestingRefresh
}

return dispatch => {
return function Intercept (opts, handler) {
if (!opts.oauthRetry && (origins.length > 0 && !origins.includes(opts.origin))) {
if (!opts.oauthRetry && !origins.includes(opts.origin)) {
// do not attempt intercept
return dispatch(opts, handler)
}

if (opts.oauthRetry) {
return callRefreshToken(refreshHost, refreshToken, client)
.catch(err => {
handler.onError(err)
})
.then(newAccessToken => {
accessToken = newAccessToken

Expand All @@ -86,7 +82,7 @@ function createOAuthInterceptor (options) {
opts.headers.authorization = `Bearer ${accessToken}`
}

const { dispatcher } = opts
const dispatcher = opts.dispatcher || getGlobalDispatcher()

const retryHandler = new RetryHandler({
...opts,
Expand All @@ -111,18 +107,22 @@ function createOAuthInterceptor (options) {
...opts.headers,
authorization: `Bearer ${accessToken}`
}
dispatcher.emit('oauth:token-refreshed', newAccessToken)
return dispatch(opts, retryHandler)
}

switch (getTokenState(accessToken)) {
case TOKEN_STATE.EXPIRED:
return callRefreshToken(refreshHost, refreshToken, client)
.then(saveTokenAndRetry)
.catch(err => handler.onError(err))
.catch(err => {
handler.onError(err)
})
case TOKEN_STATE.NEAR_EXPIRATION:
callRefreshToken(refreshHost, refreshToken, client)
.then(newAccessToken => {
accessToken = newAccessToken
dispatcher.emit('oauth:token-refreshed', newAccessToken)
})
.catch(/* do nothing */)
default:
Expand All @@ -132,4 +132,5 @@ function createOAuthInterceptor (options) {
}
}

module.exports = { createOAuthInterceptor }
module.exports = createOAuthInterceptor
module.exports.createOAuthInterceptor = createOAuthInterceptor
7 changes: 4 additions & 3 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
{
"name": "undici-oauth-interceptor",
"version": "0.2.0",
"version": "0.4.2",
"description": "Automatically manage OAuth 2.0 access tokens for Undici requests",
"main": "index.js",
"main": "oauth-interceptor.js",
"scripts": {
"lint": "standard",
"test": "node --test tests/*.test.js"
"test": "borp --coverage"
},
"author": "Platformatic",
"license": "MIT",
Expand All @@ -14,6 +14,7 @@
"undici": "^6.0.0"
},
"devDependencies": {
"borp": "^0.9.1",
"fastify": "^4.24.3",
"standard": "^17.1.0"
}
Expand Down
15 changes: 15 additions & 0 deletions renovate.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"$schema": "https://docs.renovatebot.com/renovate-schema.json",
"extends": [
"config:base"
],
"rangeStrategy": "update-lockfile",
"ignoreDeps": ["camelcase"],
"prHourlyLimit": 10,
"packageRules": [
{
"matchUpdateTypes": ["minor", "patch", "pin", "digest"],
"automerge": true
}
]
}
92 changes: 92 additions & 0 deletions tests/errors.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
'use strict'

const test = require('node:test')
const assert = require('node:assert')
const http = require('node:http')
const { once, EventEmitter } = require('node:events')
const { request, Agent } = require('undici')
const { createDecoder } = require('fast-jwt')
const { createOAuthInterceptor } = require('../')
const { createToken } = require('./helper')

test('error when refreshing', async (t) => {
const accessToken = createToken({ name: 'access' }, { expiresIn: '1ms' })

const mainServer = http.createServer((req, res) => {
assert.fail('should not be called')
})
mainServer.listen(0)

const tokenServer = http.createServer((req, res) => {
assert.strictEqual(req.method, 'POST')
assert.strictEqual(req.url, '/token')

res.writeHead(400)
res.end(JSON.stringify({ message: 'kaboom' }))
})
tokenServer.listen(0)

t.after(() => {
mainServer.close()
tokenServer.close()
})

const refreshToken = createToken(
{ name: 'refresh' },
{ expiresIn: '1d', iss: `http://localhost:${tokenServer.address().port}`, sub: 'client-id' }
)

const dispatcher = new Agent({
interceptors: {
Pool: [createOAuthInterceptor({
accessToken,
refreshToken,
retryOnStatusCodes: [401],
origins: [`http://localhost:${mainServer.address().port}`]
})]
}
})

await assert.rejects(request(`http://localhost:${mainServer.address().port}`, { dispatcher }))
})

test('after service rejects the token, token service reject token, error request', async (t) => {
const accessToken = createToken({ name: 'access' }, { expiresIn: '1d' })
const mainServer = http.createServer((req, res) => {
res.writeHead(401)
return res.end()
})
mainServer.listen(0)

const tokenServer = http.createServer((req, res) => {
assert.strictEqual(req.method, 'POST')
assert.strictEqual(req.url, '/token')

res.writeHead(403)
res.end(JSON.stringify({ message: 'kaboom' }))
})
tokenServer.listen(0)

t.after(() => {
mainServer.close()
tokenServer.close()
})

const refreshToken = createToken(
{ name: 'refresh' },
{ expiresIn: '1d', iss: `http://localhost:${tokenServer.address().port}`, sub: 'client-id' }
)

const dispatcher = new Agent({
interceptors: {
Pool: [createOAuthInterceptor({
accessToken,
refreshToken,
retryOnStatusCodes: [401],
origins: [`http://localhost:${mainServer.address().port}`]
})]
}
})

assert.rejects(request(`http://localhost:${mainServer.address().port}`, { dispatcher }))
})
60 changes: 60 additions & 0 deletions tests/global.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
'use strict'

const test = require('node:test')
const assert = require('node:assert')
const http = require('node:http')
const { once, EventEmitter } = require('node:events')
const { request, Agent, setGlobalDispatcher, getGlobalDispatcher } = require('undici')
const { createDecoder } = require('fast-jwt')
const { createOAuthInterceptor } = require('../')
const { createToken } = require('./helper')

const originalGlobalDispatcher = getGlobalDispatcher()
test.afterEach(() => setGlobalDispatcher(originalGlobalDispatcher))

test('get an access token if no token provided', async (t) => {
let accessToken = ''
const mainServer = http.createServer((req, res) => {
assert.ok(req.headers.authorization.length > 'Bearer '.length)
assert.strictEqual(req.headers.authorization, `Bearer ${accessToken}`)
res.writeHead(200)
res.end()
})
mainServer.listen(0)

const tokenServer = http.createServer((req, res) => {
assert.strictEqual(req.method, 'POST')
assert.strictEqual(req.url, '/token')

accessToken = createToken({ name: 'access' }, { expiresIn: '1d' })
res.writeHead(200)
res.end(JSON.stringify({ access_token: accessToken }))
})
tokenServer.listen(0)

t.after(() => {
mainServer.close()
tokenServer.close()
})

const refreshToken = createToken(
{ name: 'refresh' },
{ expiresIn: '1d', iss: `http://localhost:${tokenServer.address().port}` }
)

const dispatcher = new Agent({
interceptors: {
Pool: [createOAuthInterceptor({
refreshToken,
retryOnStatusCodes: [401],
origins: [`http://localhost:${mainServer.address().port}`],
clientId: 'client-id'
})]
}
})

setGlobalDispatcher(dispatcher)

const { statusCode } = await request(`http://localhost:${mainServer.address().port}`)
assert.strictEqual(statusCode, 200)
})
Loading

0 comments on commit 80dfec7

Please sign in to comment.