Skip to content

Commit

Permalink
nullability and sealing (#840)
Browse files Browse the repository at this point in the history
* nullability and sealing

* nullability and sealing

* nullability and sealing

* nullability and sealing
  • Loading branch information
jan-olaveide authored Jan 22, 2024
1 parent 093a28e commit e20f7d1
Show file tree
Hide file tree
Showing 15 changed files with 69 additions and 57 deletions.
3 changes: 1 addition & 2 deletions token-client-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
<scope>test</scope>
<version>2.16.1</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
Expand All @@ -44,7 +44,6 @@
<groupId>com.fasterxml.jackson.module</groupId>
<artifactId>jackson-module-kotlin</artifactId>
<version>2.16.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.jetbrains.kotlin</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.nimbusds.oauth2.sdk.GrantType.JWT_BEARER
import com.nimbusds.oauth2.sdk.GrantType.TOKEN_EXCHANGE
import com.nimbusds.oauth2.sdk.ParseException
import com.nimbusds.oauth2.sdk.`as`.AuthorizationServerMetadata
import com.nimbusds.oauth2.sdk.`as`.AuthorizationServerMetadata.*
import java.io.IOException
import java.net.URI

Expand Down Expand Up @@ -41,7 +42,7 @@ class ClientProperties @JvmOverloads constructor(var tokenEndpointUrl: URI? = nu

private fun endpointUrlFromMetadata(wellKnown: URI?) =
runCatching {
wellKnown?.let { AuthorizationServerMetadata.parse(DefaultResourceRetriever().retrieveResource(wellKnown.toURL()).content).tokenEndpointURI }
wellKnown?.let { parse(DefaultResourceRetriever().retrieveResource(wellKnown.toURL()).content).tokenEndpointURI }
?: throw OAuth2ClientException("Well-known url cannot be null, please check your configuration")
}.getOrElse {
when(it) {
Expand Down Expand Up @@ -73,8 +74,9 @@ class ClientProperties @JvmOverloads constructor(var tokenEndpointUrl: URI? = nu
}


class TokenExchangeProperties @JvmOverloads constructor(val audience: String, var resource: String? = null) {

fun subjectTokenType() = "urn:ietf:params:oauth:token-type:jwt"
data class TokenExchangeProperties @JvmOverloads constructor(val audience: String, var resource: String? = null) {
companion object {
const val SUBJECT_TOKEN_TYPE_VALUE = "urn:ietf:params:oauth:token-type:jwt"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ object OAuth2CacheFactory {
private fun <T> evictOnResponseExpiresIn(skewInSeconds : Long) : Expiry<T, OAuth2AccessTokenResponse> {
return object : Expiry<T, OAuth2AccessTokenResponse> {
override fun expireAfterCreate(key : T, response : OAuth2AccessTokenResponse, currentTime : Long) =
SECONDS.toNanos(if (response.expiresIn!! > skewInSeconds) response.expiresIn!! - skewInSeconds else response.expiresIn!!.toLong())
SECONDS.toNanos(if (response.expiresIn!! > skewInSeconds) response.expiresIn - skewInSeconds else response.expiresIn.toLong())
override fun expireAfterUpdate(key : T, response : OAuth2AccessTokenResponse, currentTime : Long, currentDuration : Long) = currentDuration
override fun expireAfterRead(key : T, response : OAuth2AccessTokenResponse, currentTime : Long, currentDuration : Long) = currentDuration
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ class SimpleOAuth2HttpClient : OAuth2HttpClient {
.sendRequest()
.processResponse()

private fun HttpRequest.Builder.configureRequest(request: OAuth2HttpRequest): HttpRequest.Builder {
request.oAuth2HttpHeaders.headers.forEach { (key, values) -> values.forEach { header(key, it) } }
uri(request.tokenEndpointUrl)
POST(BodyPublishers.ofString(request.formParameters.toUrlEncodedString()))
private fun HttpRequest.Builder.configureRequest(req: OAuth2HttpRequest): HttpRequest.Builder {
req.oAuth2HttpHeaders.headers.forEach { (key, values) -> values.forEach { header(key, it) } }
uri(req.tokenEndpointUrl)
POST(BodyPublishers.ofString(req.formParameters.toUrlEncodedString()))
return this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ abstract class AbstractOAuth2GrantRequest(val grantType : GrantType, val clientP
return grantType == that.grantType && clientProperties == that.clientProperties
}

fun scopes() = clientProperties.scope.joinToString(" ")

override fun hashCode() = Objects.hash(grantType, clientProperties)
override fun toString() = "${javaClass.getSimpleName()} [oAuth2GrantType=$grantType, clientProperties=$clientProperties]"
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import no.nav.security.token.support.client.core.http.OAuth2HttpClient
import no.nav.security.token.support.client.core.http.OAuth2HttpHeaders
import no.nav.security.token.support.client.core.http.OAuth2HttpRequest

abstract class AbstractOAuth2TokenClient<T : AbstractOAuth2GrantRequest> internal constructor(private val oAuth2HttpClient : OAuth2HttpClient) {
sealed class AbstractOAuth2TokenClient<T : AbstractOAuth2GrantRequest>(private val oAuth2HttpClient : OAuth2HttpClient) {

protected abstract fun formParameters(grantRequest : T) : Map<String, String>

Expand Down Expand Up @@ -57,11 +57,11 @@ abstract class AbstractOAuth2TokenClient<T : AbstractOAuth2GrantRequest> interna
}

private fun defaultFormParameters(grantRequest : T) : MutableMap<String, String> =
with(grantRequest.clientProperties) {
defaultClientAuthenticationFormParameters(grantRequest).apply {
put(GRANT_TYPE,grantRequest.grantType.value)
with(grantRequest) {
defaultClientAuthenticationFormParameters(this).apply {
put(GRANT_TYPE,grantType.value)
if (TOKEN_EXCHANGE != grantType) {
put(SCOPE, join(" ", scope))
put(SCOPE, scopes())
}
}
}
Expand Down Expand Up @@ -92,6 +92,7 @@ abstract class AbstractOAuth2TokenClient<T : AbstractOAuth2GrantRequest> interna
}
}


override fun toString() = "${javaClass.getSimpleName()} [oAuth2HttpClient=$oAuth2HttpClient]"

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ import no.nav.security.token.support.client.core.http.OAuth2HttpClient
class ClientCredentialsTokenClient(oAuth2HttpClient : OAuth2HttpClient) : AbstractOAuth2TokenClient<ClientCredentialsGrantRequest>(oAuth2HttpClient) {

override fun formParameters(grantRequest : ClientCredentialsGrantRequest) = LinkedHashMap<String, String>().apply {
put(SCOPE, grantRequest.clientProperties.scope.joinToString(" "))
put(SCOPE, grantRequest.scopes())
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package no.nav.security.token.support.client.core.oauth2

data class OAuth2AccessTokenResponse (var access_token : String? = null, var expires_at : Int? = null, var expires_in : Int? = 60, private val additionalParameters : Map<String, Any> = emptyMap()) {
val accessToken get() = access_token
val expiresAt get() = expires_at
val expiresIn get() = expires_in
val accessToken = access_token
val expiresAt = expires_at
val expiresIn = expires_in
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ class OAuth2AccessTokenService @JvmOverloads constructor(private val tokenResolv



fun getAccessToken(clientProperties : ClientProperties) : OAuth2AccessTokenResponse? {
log.trace("Getting access_token for grant={}", clientProperties.grantType)
return when (clientProperties.grantType) {
JWT_BEARER -> executeOnBehalfOf(clientProperties)
CLIENT_CREDENTIALS -> executeClientCredentials(clientProperties)
TOKEN_EXCHANGE -> executeTokenExchange(clientProperties)
else -> throw OAuth2ClientException("Invalid grant-type ${clientProperties.grantType.value} from OAuth2ClientConfig.OAuth2Client. grant-type not in supported grant-types ($SUPPORTED_GRANT_TYPES)")
fun getAccessToken(p : ClientProperties) : OAuth2AccessTokenResponse {
return when (p.grantType) {
JWT_BEARER -> executeOnBehalfOf(p)
CLIENT_CREDENTIALS -> executeClientCredentials(p)
TOKEN_EXCHANGE -> executeTokenExchange(p)
else -> throw OAuth2ClientException("Invalid grant-type ${p.grantType.value} from OAuth2ClientConfig.OAuth2Client. grant-type not in supported grant-types ($SUPPORTED_GRANT_TYPES)")
}.also {
log.debug("Got access_token for grant={}", p.grantType)
}
}

Expand All @@ -51,7 +52,7 @@ class OAuth2AccessTokenService @JvmOverloads constructor(private val tokenResolv
private val SUPPORTED_GRANT_TYPES = listOf(JWT_BEARER, CLIENT_CREDENTIALS, TOKEN_EXCHANGE
)
private val log = LoggerFactory.getLogger(OAuth2AccessTokenService::class.java)
private fun <T : AbstractOAuth2GrantRequest?> getFromCacheIfEnabled(grantRequest : T, cache : Cache<T, OAuth2AccessTokenResponse>?, client : Function<T, OAuth2AccessTokenResponse?>) =
private fun <T : AbstractOAuth2GrantRequest?> getFromCacheIfEnabled(grantRequest : T, cache : Cache<T, OAuth2AccessTokenResponse>?, client : Function<T, OAuth2AccessTokenResponse>) =
cache?.let {
log.debug("Cache is enabled so attempt to get from cache or update cache if not present.")
cache[grantRequest, client]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class OnBehalfOfTokenClient(oAuth2HttpClient : OAuth2HttpClient) : AbstractOAuth
LinkedHashMap<String, String>().apply {
put(ASSERTION, grantRequest.assertion)
put(REQUESTED_TOKEN_USE,REQUESTED_TOKEN_USE_VALUE)
put(SCOPE, grantRequest.clientProperties.scope.joinToString(" "))
put(SCOPE, grantRequest.scopes())

}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,27 @@
package no.nav.security.token.support.client.core.oauth2

import com.nimbusds.oauth2.sdk.GrantType
import no.nav.security.token.support.client.core.ClientProperties.TokenExchangeProperties.Companion.SUBJECT_TOKEN_TYPE_VALUE
import no.nav.security.token.support.client.core.OAuth2ClientException
import no.nav.security.token.support.client.core.OAuth2ParameterNames
import no.nav.security.token.support.client.core.OAuth2ParameterNames.AUDIENCE
import no.nav.security.token.support.client.core.OAuth2ParameterNames.RESOURCE
import no.nav.security.token.support.client.core.OAuth2ParameterNames.SCOPE
import no.nav.security.token.support.client.core.OAuth2ParameterNames.SUBJECT_TOKEN
import no.nav.security.token.support.client.core.OAuth2ParameterNames.SUBJECT_TOKEN_TYPE
import no.nav.security.token.support.client.core.http.OAuth2HttpClient

class TokenExchangeClient(oAuth2HttpClient : OAuth2HttpClient) : AbstractOAuth2TokenClient<TokenExchangeGrantRequest>(oAuth2HttpClient) {

override fun formParameters(grantRequest : TokenExchangeGrantRequest) =
LinkedHashMap<String, String>().apply {
grantRequest.clientProperties.tokenExchange.run {
put(SUBJECT_TOKEN_TYPE, this!!.subjectTokenType())
put(SUBJECT_TOKEN,grantRequest.subjectToken)
put(AUDIENCE, audience)
resource?.takeIf { it.isNotEmpty() }?.let { put(RESOURCE, it) }
with(grantRequest) {
HashMap<String, String>().apply {
clientProperties.tokenExchange?.run {
put(SUBJECT_TOKEN_TYPE, SUBJECT_TOKEN_TYPE_VALUE)
put(SUBJECT_TOKEN,subjectToken)
put(AUDIENCE, audience)
resource?.takeIf { it.isNotEmpty() }?.let { put(RESOURCE, it) }
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ internal class OAuth2AccessTokenServiceTest {
val res = oAuth2AccessTokenService.getAccessToken(onBehalfOfProperties())
verify(onBehalfOfTokenResponseClient).getTokenResponse(reifiedAny( OnBehalfOfGrantRequest::class.java))
assertThat(res).hasNoNullFieldsOrProperties()
assertThat(res!!.accessToken).isEqualTo("first_access_token")
assertThat(res.accessToken).isEqualTo("first_access_token")
}

@Test
Expand All @@ -76,7 +76,7 @@ internal class OAuth2AccessTokenServiceTest {
val res = oAuth2AccessTokenService.getAccessToken(clientCredentialsProperties())
verify(clientCredentialsTokenResponseClient).getTokenResponse(reifiedAny(ClientCredentialsGrantRequest::class.java))
assertThat(res).hasNoNullFieldsOrProperties()
assertThat(res!!.accessToken).isEqualTo("first_access_token")
assertThat(res.accessToken).isEqualTo("first_access_token")
}

@Test
Expand All @@ -98,13 +98,13 @@ internal class OAuth2AccessTokenServiceTest {
val res = oAuth2AccessTokenService.getAccessToken(clientProperties)
verify(onBehalfOfTokenResponseClient).getTokenResponse(reifiedAny(OnBehalfOfGrantRequest::class.java))
assertThat(res).hasNoNullFieldsOrProperties()
assertThat(res!!.accessToken).isEqualTo("first_access_token")
assertThat(res.accessToken).isEqualTo("first_access_token")

//should get response from cache and NOT invoke client
reset(onBehalfOfTokenResponseClient)
val res2 = oAuth2AccessTokenService.getAccessToken(clientProperties)
verify(onBehalfOfTokenResponseClient, never()).getTokenResponse(reifiedAny(OnBehalfOfGrantRequest::class.java))
assertThat(res2!!.accessToken).isEqualTo("first_access_token")
assertThat(res2.accessToken).isEqualTo("first_access_token")

//another user/token but same clientconfig, should invoke client and populate cache
reset(assertionResolver)
Expand All @@ -115,7 +115,7 @@ internal class OAuth2AccessTokenServiceTest {
.thenReturn(accessTokenResponse(secondAccessToken, 60))
val res3 = oAuth2AccessTokenService.getAccessToken(clientProperties)
verify(onBehalfOfTokenResponseClient).getTokenResponse(reifiedAny(OnBehalfOfGrantRequest::class.java))
assertThat(res3!!.accessToken).isEqualTo(secondAccessToken)
assertThat(res3.accessToken).isEqualTo(secondAccessToken)
}

@Test
Expand All @@ -130,14 +130,14 @@ internal class OAuth2AccessTokenServiceTest {
val res1 = oAuth2AccessTokenService.getAccessToken(clientProperties)
verify(clientCredentialsTokenResponseClient).getTokenResponse(reifiedAny(ClientCredentialsGrantRequest::class.java))
assertThat(res1).hasNoNullFieldsOrProperties()
assertThat(res1!!.accessToken).isEqualTo("first_access_token")
assertThat(res1.accessToken).isEqualTo("first_access_token")

//should get response from cache and NOT invoke client
reset(clientCredentialsTokenResponseClient)
val res2 = oAuth2AccessTokenService.getAccessToken(clientProperties)
verify(clientCredentialsTokenResponseClient, never()).getTokenResponse(reifiedAny(
ClientCredentialsGrantRequest::class.java))
assertThat(res2!!.accessToken).isEqualTo("first_access_token")
assertThat(res2.accessToken).isEqualTo("first_access_token")

//another clientconfig, should invoke client and populate cache
clientProperties = clientCredentialsProperties("scope3")
Expand All @@ -147,7 +147,7 @@ internal class OAuth2AccessTokenServiceTest {
.thenReturn(accessTokenResponse(secondAccessToken, 60))
val res3 = oAuth2AccessTokenService.getAccessToken(clientProperties)
verify(clientCredentialsTokenResponseClient).getTokenResponse(reifiedAny(ClientCredentialsGrantRequest::class.java))
assertThat(res3!!.accessToken).isEqualTo(secondAccessToken)
assertThat(res3.accessToken).isEqualTo(secondAccessToken)
}

@Test
Expand All @@ -163,7 +163,7 @@ internal class OAuth2AccessTokenServiceTest {
val res1 = oAuth2AccessTokenService.getAccessToken(clientProperties)
verify(onBehalfOfTokenResponseClient).getTokenResponse(reifiedAny(OnBehalfOfGrantRequest::class.java))
assertThat(res1).hasNoNullFieldsOrProperties()
assertThat(res1!!.accessToken).isEqualTo("first_access_token")
assertThat(res1.accessToken).isEqualTo("first_access_token")
Thread.sleep(1000)

//entry should be missing from cache due to expiry
Expand All @@ -173,7 +173,7 @@ internal class OAuth2AccessTokenServiceTest {
.thenReturn(accessTokenResponse(secondAccessToken, 1))
val res2 = oAuth2AccessTokenService.getAccessToken(clientProperties)
verify(onBehalfOfTokenResponseClient).getTokenResponse(reifiedAny(OnBehalfOfGrantRequest::class.java))
assertThat(res2!!.accessToken).isEqualTo(secondAccessToken)
assertThat(res2.accessToken).isEqualTo(secondAccessToken)
}

@Test
Expand All @@ -188,7 +188,7 @@ internal class OAuth2AccessTokenServiceTest {
verify(exchangeTokeResponseClient, times(1)).getTokenResponse(reifiedAny(
TokenExchangeGrantRequest::class.java))
assertThat(res1).hasNoNullFieldsOrProperties()
assertThat(res1!!.accessToken).isEqualTo("first_access_token")
assertThat(res1.accessToken).isEqualTo("first_access_token")
}

companion object {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ data class OAuth2CacheConfig(val enabled: Boolean, val maximumSize: Long = 1000,
Caffeine.newBuilder()
.expireAfter(evictOnResponseExpiresIn(evictSkew))
.maximumSize(maximumSize)
.buildAsync { key: GrantRequest, _ ->
.buildAsync { key, _ ->
cacheContext.future {
loader(key)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class OAuth2ClientRequestInterceptor(private val properties: ClientConfiguration
private val matcher: ClientConfigurationPropertiesMatcher) : ClientHttpRequestInterceptor {
override fun intercept(req: HttpRequest, body: ByteArray, execution: ClientHttpRequestExecution): ClientHttpResponse {
matcher.findProperties(properties, req.uri)?.let {
service.getAccessToken(it)?.accessToken?.let { token -> req.headers.setBearerAuth(token) }
service.getAccessToken(it).accessToken?.let { token -> req.headers.setBearerAuth(token) }
}
return execution.execute(req, body)
}
Expand Down
Loading

0 comments on commit e20f7d1

Please sign in to comment.