@ -19,7 +19,6 @@ import org.junit.After;
@@ -19,7 +19,6 @@ import org.junit.After;
import org.junit.Before ;
import org.junit.Test ;
import org.mockito.ArgumentCaptor ;
import org.springframework.http.HttpStatus ;
import org.springframework.http.converter.HttpMessageConverter ;
import org.springframework.mock.http.client.MockClientHttpResponse ;
@ -44,17 +43,15 @@ import org.springframework.security.oauth2.server.authorization.authentication.O
@@ -44,17 +43,15 @@ import org.springframework.security.oauth2.server.authorization.authentication.O
import org.springframework.security.oauth2.server.authorization.authentication.OAuth2ClientCredentialsAuthenticationToken ;
import org.springframework.security.oauth2.server.authorization.client.RegisteredClient ;
import org.springframework.security.oauth2.server.authorization.client.TestRegisteredClients ;
import org.springframework.util.StringUtils ;
import javax.servlet.FilterChain ;
import javax.servlet.ServletException ;
import javax.servlet.http.HttpServletRequest ;
import javax.servlet.http.HttpServletResponse ;
import java.io.IOException ;
import java.time.Duration ;
import java.time.Instant ;
import java.util.Arrays ;
import java.util.HashSet ;
import java.util.function.Consumer ;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThatThrownBy ;
@ -140,58 +137,77 @@ public class OAuth2TokenEndpointFilterTests {
@@ -140,58 +137,77 @@ public class OAuth2TokenEndpointFilterTests {
@Test
public void doFilterWhenTokenRequestMissingGrantTypeThenInvalidRequestError ( ) throws Exception {
MockHttpServletRequest request = createAuthorizationCodeTokenRequest (
TestRegisteredClients . registeredClient ( ) . build ( ) ) ;
request . removeParameter ( OAuth2ParameterNames . GRANT_TYPE ) ;
doFilterWhenTokenRequestInvalidParameterThenError (
OAuth2ParameterNames . GRANT_TYPE , OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . removeParameter ( OAuth2ParameterNames . GRANT_TYPE ) ) ;
OAuth2ParameterNames . GRANT_TYPE , OAuth2ErrorCodes . INVALID_REQUEST , request ) ;
}
@Test
public void doFilterWhenTokenRequestMultipleGrantTypeThenInvalidRequestError ( ) throws Exception {
MockHttpServletRequest request = createAuthorizationCodeTokenRequest (
TestRegisteredClients . registeredClient ( ) . build ( ) ) ;
request . addParameter ( OAuth2ParameterNames . GRANT_TYPE , AuthorizationGrantType . AUTHORIZATION_CODE . getValue ( ) ) ;
doFilterWhenTokenRequestInvalidParameterThenError (
OAuth2ParameterNames . GRANT_TYPE , OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . addParameter ( OAuth2ParameterNames . GRANT_TYPE , AuthorizationGrantType . AUTHORIZATION_CODE . getValue ( ) ) ) ;
OAuth2ParameterNames . GRANT_TYPE , OAuth2ErrorCodes . INVALID_REQUEST , request ) ;
}
@Test
public void doFilterWhenTokenRequestInvalidGrantTypeThenUnsupportedGrantTypeError ( ) throws Exception {
MockHttpServletRequest request = createAuthorizationCodeTokenRequest (
TestRegisteredClients . registeredClient ( ) . build ( ) ) ;
request . setParameter ( OAuth2ParameterNames . GRANT_TYPE , "invalid-grant-type" ) ;
doFilterWhenTokenRequestInvalidParameterThenError (
OAuth2ParameterNames . GRANT_TYPE , OAuth2ErrorCodes . UNSUPPORTED_GRANT_TYPE ,
request - > request . setParameter ( OAuth2ParameterNames . GRANT_TYPE , "invalid-grant-type" ) ) ;
OAuth2ParameterNames . GRANT_TYPE , OAuth2ErrorCodes . UNSUPPORTED_GRANT_TYPE , request ) ;
}
@Test
public void doFilterWhenTokenRequestMultipleClientIdThenInvalidRequestError ( ) throws Exception {
MockHttpServletRequest request = createAuthorizationCodeTokenRequest (
TestRegisteredClients . registeredClient ( ) . build ( ) ) ;
request . addParameter ( OAuth2ParameterNames . CLIENT_ID , "client-1" ) ;
request . addParameter ( OAuth2ParameterNames . CLIENT_ID , "client-2" ) ;
doFilterWhenTokenRequestInvalidParameterThenError (
OAuth2ParameterNames . CLIENT_ID , OAuth2ErrorCodes . INVALID_REQUEST ,
request - > {
request . addParameter ( OAuth2ParameterNames . CLIENT_ID , "client-1" ) ;
request . addParameter ( OAuth2ParameterNames . CLIENT_ID , "client-2" ) ;
} ) ;
OAuth2ParameterNames . CLIENT_ID , OAuth2ErrorCodes . INVALID_REQUEST , request ) ;
}
@Test
public void doFilterWhenTokenRequestMissingCodeThenInvalidRequestError ( ) throws Exception {
MockHttpServletRequest request = createAuthorizationCodeTokenRequest (
TestRegisteredClients . registeredClient ( ) . build ( ) ) ;
request . removeParameter ( OAuth2ParameterNames . CODE ) ;
doFilterWhenTokenRequestInvalidParameterThenError (
OAuth2ParameterNames . CODE , OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . removeParameter ( OAuth2ParameterNames . CODE ) ) ;
OAuth2ParameterNames . CODE , OAuth2ErrorCodes . INVALID_REQUEST , request ) ;
}
@Test
public void doFilterWhenTokenRequestMultipleCodeThenInvalidRequestError ( ) throws Exception {
MockHttpServletRequest request = createAuthorizationCodeTokenRequest (
TestRegisteredClients . registeredClient ( ) . build ( ) ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code-2" ) ;
doFilterWhenTokenRequestInvalidParameterThenError (
OAuth2ParameterNames . CODE , OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . addParameter ( OAuth2ParameterNames . CODE , "code-2" ) ) ;
OAuth2ParameterNames . CODE , OAuth2ErrorCodes . INVALID_REQUEST , request ) ;
}
@Test
public void doFilterWhenTokenRequestMultipleRedirectUriThenInvalidRequestError ( ) throws Exception {
MockHttpServletRequest request = createAuthorizationCodeTokenRequest (
TestRegisteredClients . registeredClient ( ) . build ( ) ) ;
request . addParameter ( OAuth2ParameterNames . REDIRECT_URI , "https://example2.com" ) ;
doFilterWhenTokenRequestInvalidParameterThenError (
OAuth2ParameterNames . REDIRECT_URI , OAuth2ErrorCodes . INVALID_REQUEST ,
request - > request . addParameter ( OAuth2ParameterNames . REDIRECT_URI , "https://example2.com" ) ) ;
OAuth2ParameterNames . REDIRECT_URI , OAuth2ErrorCodes . INVALID_REQUEST , request ) ;
}
@Test
public void doFilterWhenTokenRequestValidThenAccessTokenResponse ( ) throws Exception {
public void doFilterWhenAuthorizationCode TokenRequestValidThenAccessTokenResponse ( ) throws Exception {
RegisteredClient registeredClient = TestRegisteredClients . registeredClient ( ) . build ( ) ;
Authentication clientPrincipal = new OAuth2ClientAuthenticationToken ( registeredClient ) ;
OAuth2AccessToken accessToken = new OAuth2AccessToken (
@ -208,7 +224,7 @@ public class OAuth2TokenEndpointFilterTests {
@@ -208,7 +224,7 @@ public class OAuth2TokenEndpointFilterTests {
securityContext . setAuthentication ( clientPrincipal ) ;
SecurityContextHolder . setContext ( securityContext ) ;
MockHttpServletRequest request = createTokenRequest ( registeredClient ) ;
MockHttpServletRequest request = createAuthorizationCode TokenRequest ( registeredClient ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
@ -242,38 +258,24 @@ public class OAuth2TokenEndpointFilterTests {
@@ -242,38 +258,24 @@ public class OAuth2TokenEndpointFilterTests {
}
@Test
public void doFilterWhenGrantTypeIsClientCredentialsThenAuthenticateWithClientCredentialsToken ( ) throws ServletException , IO Exception {
RegisteredClient registeredClient = TestRegisteredClients . registeredClient ( ) . build ( ) ;
doFilterForClientCredentialsGrant ( registeredClient , null ) ;
public void doFilterWhenTokenRequestMultipleScopeThenInvalidRequestError ( ) throws Exception {
RegisteredClient registeredClient = TestRegisteredClients . registeredClient2 ( ) . build ( ) ;
Authentication clientPrincipal = new OAuth2ClientAuthenticationToken ( registeredClient ) ;
ArgumentCaptor < Authentication > captor = ArgumentCaptor . forClass ( Authentication . class ) ;
verify ( this . authenticationManager ) . authenticate ( captor . capture ( ) ) ;
SecurityContext securityContext = SecurityContextHolder . createEmptyContext ( ) ;
securityContext . setAuthentication ( clientPrincipal ) ;
SecurityContextHolder . setContext ( securityContext ) ;
assertThat ( captor . getValue ( ) ) . isInstanceOf ( OAuth2ClientCredentialsAuthenticationToken . class ) ;
OAuth2ClientCredentialsAuthenticationToken clientAuthenticationToken = ( OAuth2ClientCredentialsAuthenticationToken ) captor . getValue ( ) ;
MockHttpServletRequest request = createClientCredentialsTokenRequest ( registeredClient ) ;
request . addParameter ( OAuth2ParameterNames . SCOPE , "profile" ) ;
assertThat ( clientAuthenticationToken . getPrincipal ( ) ) . isEqualTo ( new OAuth2ClientAuthenticationToken ( registeredClient ) ) ;
doFilterWhenTokenRequestInvalidParameterThenError (
OAuth2ParameterNames . SCOPE , OAuth2ErrorCodes . INVALID_REQUEST , request ) ;
}
@Test
public void doFilterWhenGrantTypeIsClientCredentialsWithScopeThenIncludeScopeInResponse ( ) throws ServletException , IOException {
RegisteredClient registeredClient = TestRegisteredClients . registeredClient ( ) . build ( ) ;
doFilterForClientCredentialsGrant ( registeredClient , "openid email" ) ;
ArgumentCaptor < Authentication > captor = ArgumentCaptor . forClass ( Authentication . class ) ;
verify ( this . authenticationManager ) . authenticate ( captor . capture ( ) ) ;
assertThat ( captor . getValue ( ) ) . isInstanceOf ( OAuth2ClientCredentialsAuthenticationToken . class ) ;
OAuth2ClientCredentialsAuthenticationToken clientAuthenticationToken = ( OAuth2ClientCredentialsAuthenticationToken ) captor . getValue ( ) ;
HashSet < String > expectedScopes = new HashSet < > ( ) ;
expectedScopes . add ( "openid" ) ;
expectedScopes . add ( "email" ) ;
assertThat ( clientAuthenticationToken . getScopes ( ) ) . isEqualTo ( expectedScopes ) ;
}
private void doFilterForClientCredentialsGrant ( RegisteredClient registeredClient , String scope ) throws ServletException , IOException {
public void doFilterWhenClientCredentialsTokenRequestValidThenAccessTokenResponse ( ) throws Exception {
RegisteredClient registeredClient = TestRegisteredClients . registeredClient2 ( ) . build ( ) ;
Authentication clientPrincipal = new OAuth2ClientAuthenticationToken ( registeredClient ) ;
OAuth2AccessToken accessToken = new OAuth2AccessToken (
OAuth2AccessToken . TokenType . BEARER , "token" ,
@ -282,35 +284,46 @@ public class OAuth2TokenEndpointFilterTests {
@@ -282,35 +284,46 @@ public class OAuth2TokenEndpointFilterTests {
OAuth2AccessTokenAuthenticationToken accessTokenAuthentication =
new OAuth2AccessTokenAuthenticationToken (
registeredClient , clientPrincipal , accessToken ) ;
final String clientId = registeredClient . getClientId ( ) ;
final String clientSecret = registeredClient . getClientSecret ( ) ;
MockHttpServletRequest request = new MockHttpServletRequest ( "POST" , OAuth2TokenEndpointFilter . DEFAULT_TOKEN_ENDPOINT_URI ) ;
request . setServletPath ( OAuth2TokenEndpointFilter . DEFAULT_TOKEN_ENDPOINT_URI ) ;
request . addParameter ( "client_id" , clientId ) ;
request . addParameter ( "client_secret" , clientSecret ) ;
request . addParameter ( "grant_type" , AuthorizationGrantType . CLIENT_CREDENTIALS . getValue ( ) ) ;
if ( scope ! = null ) {
request . addParameter ( "scope" , scope ) ;
}
when ( this . authenticationManager . authenticate ( any ( ) ) ) . thenReturn ( accessTokenAuthentication ) ;
SecurityContext context = SecurityContextHolder . createEmptyContext ( ) ;
context . setAuthentication ( new OAuth2ClientAuthenticationToken ( registeredClient ) ) ;
SecurityContextHolder . setContext ( context ) ;
SecurityContext securityContext = SecurityContextHolder . createEmptyContext ( ) ;
securityContext . setAuthentication ( clientPrincipal ) ;
SecurityContextHolder . setContext ( securityContext ) ;
MockHttpServletRequest request = createClientCredentialsTokenRequest ( registeredClient ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
filter . doFilter ( request , response , mock ( FilterChain . class ) ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . filter . doFilter ( request , response , filterChain ) ;
verifyNoInteractions ( filterChain ) ;
ArgumentCaptor < OAuth2ClientCredentialsAuthenticationToken > clientCredentialsAuthenticationCaptor =
ArgumentCaptor . forClass ( OAuth2ClientCredentialsAuthenticationToken . class ) ;
verify ( this . authenticationManager ) . authenticate ( clientCredentialsAuthenticationCaptor . capture ( ) ) ;
OAuth2ClientCredentialsAuthenticationToken clientCredentialsAuthentication =
clientCredentialsAuthenticationCaptor . getValue ( ) ;
assertThat ( clientCredentialsAuthentication . getPrincipal ( ) ) . isEqualTo ( clientPrincipal ) ;
assertThat ( clientCredentialsAuthentication . getScopes ( ) ) . isEqualTo ( registeredClient . getScopes ( ) ) ;
assertThat ( response . getStatus ( ) ) . isEqualTo ( HttpStatus . OK . value ( ) ) ;
OAuth2AccessTokenResponse accessTokenResponse = readAccessTokenResponse ( response ) ;
OAuth2AccessToken accessTokenResult = accessTokenResponse . getAccessToken ( ) ;
assertThat ( accessTokenResult . getTokenType ( ) ) . isEqualTo ( accessToken . getTokenType ( ) ) ;
assertThat ( accessTokenResult . getTokenValue ( ) ) . isEqualTo ( accessToken . getTokenValue ( ) ) ;
assertThat ( accessTokenResult . getIssuedAt ( ) ) . isBetween (
accessToken . getIssuedAt ( ) . minusSeconds ( 1 ) , accessToken . getIssuedAt ( ) . plusSeconds ( 1 ) ) ;
assertThat ( accessTokenResult . getExpiresAt ( ) ) . isBetween (
accessToken . getExpiresAt ( ) . minusSeconds ( 1 ) , accessToken . getExpiresAt ( ) . plusSeconds ( 1 ) ) ;
assertThat ( accessTokenResult . getScopes ( ) ) . isEqualTo ( accessToken . getScopes ( ) ) ;
}
private void doFilterWhenTokenRequestInvalidParameterThenError ( String parameterName , String errorCode ,
Consumer < MockHttpServletRequest > requestConsumer ) throws Exception {
MockHttpServletRequest request ) throws Exception {
RegisteredClient registeredClient = TestRegisteredClients . registeredClient ( ) . build ( ) ;
MockHttpServletRequest request = createTokenRequest ( registeredClient ) ;
requestConsumer . accept ( request ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
@ -336,7 +349,7 @@ public class OAuth2TokenEndpointFilterTests {
@@ -336,7 +349,7 @@ public class OAuth2TokenEndpointFilterTests {
return this . accessTokenHttpResponseConverter . read ( OAuth2AccessTokenResponse . class , httpResponse ) ;
}
private static MockHttpServletRequest createTokenRequest ( RegisteredClient registeredClient ) {
private static MockHttpServletRequest createAuthorizationCode TokenRequest ( RegisteredClient registeredClient ) {
String [ ] redirectUris = registeredClient . getRedirectUris ( ) . toArray ( new String [ 0 ] ) ;
String requestUri = OAuth2TokenEndpointFilter . DEFAULT_TOKEN_ENDPOINT_URI ;
@ -349,4 +362,16 @@ public class OAuth2TokenEndpointFilterTests {
@@ -349,4 +362,16 @@ public class OAuth2TokenEndpointFilterTests {
return request ;
}
private static MockHttpServletRequest createClientCredentialsTokenRequest ( RegisteredClient registeredClient ) {
String requestUri = OAuth2TokenEndpointFilter . DEFAULT_TOKEN_ENDPOINT_URI ;
MockHttpServletRequest request = new MockHttpServletRequest ( "POST" , requestUri ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . GRANT_TYPE , AuthorizationGrantType . CLIENT_CREDENTIALS . getValue ( ) ) ;
request . addParameter ( OAuth2ParameterNames . SCOPE ,
StringUtils . collectionToDelimitedString ( registeredClient . getScopes ( ) , " " ) ) ;
return request ;
}
}