@ -44,9 +44,12 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
@@ -44,9 +44,12 @@ import org.springframework.security.oauth2.core.OAuth2ErrorCodes;
import org.springframework.security.oauth2.core.OAuth2RefreshToken ;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationExchange ;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest ;
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationResponse ;
import org.springframework.security.oauth2.core.endpoint.OAuth2ParameterNames ;
import org.springframework.security.oauth2.core.user.OAuth2User ;
import org.springframework.security.web.authentication.AuthenticationFailureHandler ;
import org.springframework.security.web.util.UrlUtils ;
import org.springframework.web.util.UriComponentsBuilder ;
import javax.servlet.FilterChain ;
import javax.servlet.http.HttpServletRequest ;
@ -64,7 +67,7 @@ import static org.mockito.Mockito.*;
@@ -64,7 +67,7 @@ import static org.mockito.Mockito.*;
* @author Joe Grandja
* /
@PowerMockIgnore ( "javax.security.*" )
@PrepareForTest ( { OAuth2AuthorizationRequest . class , OAuth2Authorization Exchange . class , OAuth2LoginAuthenticationFilter . class } )
@PrepareForTest ( { OAuth2AuthorizationExchange . class , OAuth2LoginAuthenticationFilter . class } )
@RunWith ( PowerMockRunner . class )
public class OAuth2LoginAuthenticationFilterTests {
private ClientRegistration registration1 ;
@ -298,16 +301,137 @@ public class OAuth2LoginAuthenticationFilterTests {
@@ -298,16 +301,137 @@ public class OAuth2LoginAuthenticationFilterTests {
verify ( this . filter ) . attemptAuthentication ( any ( HttpServletRequest . class ) , any ( HttpServletResponse . class ) ) ;
}
// gh-5890
@Test
public void doFilterWhenAuthorizationResponseHasDefaultPort80ThenRedirectUriMatchingExcludesPort ( ) throws Exception {
String requestUri = "/login/oauth2/code/" + this . registration2 . getRegistrationId ( ) ;
String state = "state" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setScheme ( "http" ) ;
request . setServerName ( "example.com" ) ;
request . setServerPort ( 80 ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . setUpAuthorizationRequest ( request , response , this . registration2 , state ) ;
this . setUpAuthenticationResult ( this . registration2 ) ;
this . filter . doFilter ( request , response , filterChain ) ;
ArgumentCaptor < Authentication > authenticationArgCaptor = ArgumentCaptor . forClass ( Authentication . class ) ;
verify ( this . authenticationManager ) . authenticate ( authenticationArgCaptor . capture ( ) ) ;
OAuth2LoginAuthenticationToken authentication = ( OAuth2LoginAuthenticationToken ) authenticationArgCaptor . getValue ( ) ;
OAuth2AuthorizationRequest authorizationRequest = authentication . getAuthorizationExchange ( ) . getAuthorizationRequest ( ) ;
OAuth2AuthorizationResponse authorizationResponse = authentication . getAuthorizationExchange ( ) . getAuthorizationResponse ( ) ;
String expectedRedirectUri = "http://example.com/login/oauth2/code/registration-id-2" ;
assertThat ( authorizationRequest . getRedirectUri ( ) ) . isEqualTo ( expectedRedirectUri ) ;
assertThat ( authorizationResponse . getRedirectUri ( ) ) . isEqualTo ( expectedRedirectUri ) ;
}
// gh-5890
@Test
public void doFilterWhenAuthorizationResponseHasDefaultPort443ThenRedirectUriMatchingExcludesPort ( ) throws Exception {
String requestUri = "/login/oauth2/code/" + this . registration2 . getRegistrationId ( ) ;
String state = "state" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setScheme ( "https" ) ;
request . setServerName ( "example.com" ) ;
request . setServerPort ( 443 ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . setUpAuthorizationRequest ( request , response , this . registration2 , state ) ;
this . setUpAuthenticationResult ( this . registration2 ) ;
this . filter . doFilter ( request , response , filterChain ) ;
ArgumentCaptor < Authentication > authenticationArgCaptor = ArgumentCaptor . forClass ( Authentication . class ) ;
verify ( this . authenticationManager ) . authenticate ( authenticationArgCaptor . capture ( ) ) ;
OAuth2LoginAuthenticationToken authentication = ( OAuth2LoginAuthenticationToken ) authenticationArgCaptor . getValue ( ) ;
OAuth2AuthorizationRequest authorizationRequest = authentication . getAuthorizationExchange ( ) . getAuthorizationRequest ( ) ;
OAuth2AuthorizationResponse authorizationResponse = authentication . getAuthorizationExchange ( ) . getAuthorizationResponse ( ) ;
String expectedRedirectUri = "https://example.com/login/oauth2/code/registration-id-2" ;
assertThat ( authorizationRequest . getRedirectUri ( ) ) . isEqualTo ( expectedRedirectUri ) ;
assertThat ( authorizationResponse . getRedirectUri ( ) ) . isEqualTo ( expectedRedirectUri ) ;
}
// gh-5890
@Test
public void doFilterWhenAuthorizationResponseHasNonDefaultPortThenRedirectUriMatchingIncludesPort ( ) throws Exception {
String requestUri = "/login/oauth2/code/" + this . registration2 . getRegistrationId ( ) ;
String state = "state" ;
MockHttpServletRequest request = new MockHttpServletRequest ( "GET" , requestUri ) ;
request . setScheme ( "https" ) ;
request . setServerName ( "example.com" ) ;
request . setServerPort ( 9090 ) ;
request . setServletPath ( requestUri ) ;
request . addParameter ( OAuth2ParameterNames . CODE , "code" ) ;
request . addParameter ( OAuth2ParameterNames . STATE , "state" ) ;
MockHttpServletResponse response = new MockHttpServletResponse ( ) ;
FilterChain filterChain = mock ( FilterChain . class ) ;
this . setUpAuthorizationRequest ( request , response , this . registration2 , state ) ;
this . setUpAuthenticationResult ( this . registration2 ) ;
this . filter . doFilter ( request , response , filterChain ) ;
ArgumentCaptor < Authentication > authenticationArgCaptor = ArgumentCaptor . forClass ( Authentication . class ) ;
verify ( this . authenticationManager ) . authenticate ( authenticationArgCaptor . capture ( ) ) ;
OAuth2LoginAuthenticationToken authentication = ( OAuth2LoginAuthenticationToken ) authenticationArgCaptor . getValue ( ) ;
OAuth2AuthorizationRequest authorizationRequest = authentication . getAuthorizationExchange ( ) . getAuthorizationRequest ( ) ;
OAuth2AuthorizationResponse authorizationResponse = authentication . getAuthorizationExchange ( ) . getAuthorizationResponse ( ) ;
String expectedRedirectUri = "https://example.com:9090/login/oauth2/code/registration-id-2" ;
assertThat ( authorizationRequest . getRedirectUri ( ) ) . isEqualTo ( expectedRedirectUri ) ;
assertThat ( authorizationResponse . getRedirectUri ( ) ) . isEqualTo ( expectedRedirectUri ) ;
}
private void setUpAuthorizationRequest ( HttpServletRequest request , HttpServletResponse response ,
ClientRegistration registration , String state ) {
OAuth2AuthorizationRequest authorizationRequest = mock ( OAuth2AuthorizationRequest . class ) ;
when ( authorizationRequest . getState ( ) ) . thenReturn ( state ) ;
Map < String , Object > additionalParameters = new HashMap < > ( ) ;
additionalParameters . put ( OAuth2ParameterNames . REGISTRATION_ID , registration . getRegistrationId ( ) ) ;
when ( authorizationRequest . getAdditionalParameters ( ) ) . thenReturn ( additionalParameters ) ;
OAuth2AuthorizationRequest authorizationRequest = OAuth2AuthorizationRequest . authorizationCode ( )
. authorizationUri ( registration . getProviderDetails ( ) . getAuthorizationUri ( ) )
. clientId ( registration . getClientId ( ) )
. redirectUri ( expandRedirectUri ( request , registration ) )
. scopes ( registration . getScopes ( ) )
. state ( state )
. additionalParameters ( additionalParameters )
. build ( ) ;
this . authorizationRequestRepository . saveAuthorizationRequest ( authorizationRequest , request , response ) ;
}
private String expandRedirectUri ( HttpServletRequest request , ClientRegistration clientRegistration ) {
String baseUrl = UriComponentsBuilder . fromHttpUrl ( UrlUtils . buildFullRequestUrl ( request ) )
. replaceQuery ( null )
. replacePath ( request . getContextPath ( ) )
. build ( )
. toUriString ( ) ;
Map < String , String > uriVariables = new HashMap < > ( ) ;
uriVariables . put ( "baseUrl" , baseUrl ) ;
uriVariables . put ( "action" , "login" ) ;
uriVariables . put ( "registrationId" , clientRegistration . getRegistrationId ( ) ) ;
return UriComponentsBuilder . fromUriString ( clientRegistration . getRedirectUriTemplate ( ) )
. buildAndExpand ( uriVariables )
. toUriString ( ) ;
}
private void setUpAuthenticationResult ( ClientRegistration registration ) {
OAuth2User user = mock ( OAuth2User . class ) ;
when ( user . getName ( ) ) . thenReturn ( this . principalName1 ) ;