@ -15,7 +15,6 @@
@@ -15,7 +15,6 @@
* /
package org.springframework.security.config.annotation.web.configurers.oauth2.client ;
import org.springframework.beans.factory.BeanFactoryUtils ;
import org.springframework.context.ApplicationContext ;
import org.springframework.core.ResolvableType ;
import org.springframework.security.config.annotation.web.HttpSecurityBuilder ;
@ -39,15 +38,12 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
@@ -39,15 +38,12 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher ;
import org.springframework.security.web.util.matcher.RequestVariablesExtractor ;
import org.springframework.util.Assert ;
import org.springframework.util.CollectionUtils ;
import java.net.URI ;
import java.util.ArrayList ;
import java.util.Arrays ;
import java.util.Collection ;
import java.util.List ;
import java.util.HashMap ;
import java.util.Map ;
import java.util.stream.Collectors ;
import java.util.stream.Stream ;
import static org.springframework.security.oauth2.client.web.AuthorizationCodeRequestRedirectFilter.REGISTRATION_ID_URI_VARIABLE_NAME ;
@ -75,7 +71,6 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
@@ -75,7 +71,6 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
public OAuth2LoginConfigurer < H > clients ( ClientRegistration . . . clientRegistrations ) {
Assert . notEmpty ( clientRegistrations , "clientRegistrations cannot be empty" ) ;
this . getBuilder ( ) . setSharedObject ( ClientRegistration [ ] . class , clientRegistrations ) ;
return this . clients ( new InMemoryClientRegistrationRepository ( Arrays . asList ( clientRegistrations ) ) ) ;
}
@ -230,56 +225,24 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
@@ -230,56 +225,24 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
}
private static < H extends HttpSecurityBuilder < H > > ClientRegistrationRepository getDefaultClientRegistrationRepository ( H http ) {
List < ClientRegistration > clientRegistrations = getClientRegistrations ( http ) ;
if ( ! CollectionUtils . isEmpty ( clientRegistrations ) ) {
return new InMemoryClientRegistrationRepository ( clientRegistrations ) ;
}
return http . getSharedObject ( ApplicationContext . class ) . getBean ( ClientRegistrationRepository . class ) ;
}
private static < H extends HttpSecurityBuilder < H > > List < ClientRegistration > getClientRegistrations ( H http ) {
ClientRegistration [ ] clientRegistrations = http . getSharedObject ( ClientRegistration [ ] . class ) ;
if ( clientRegistrations ! = null ) {
return Arrays . asList ( clientRegistrations ) ;
}
List < ClientRegistration > clientRegistrationsList = new ArrayList < > ( ) ;
// Check context for type -> Collection<ClientRegistration>
ResolvableType clientRegistrationsType = ResolvableType . forClassWithGenerics (
Collection . class , ClientRegistration . class ) ;
Map < String , ? > clientRegistrationsMap = BeanFactoryUtils . beansOfTypeIncludingAncestors (
http . getSharedObject ( ApplicationContext . class ) ,
clientRegistrationsType . resolve ( Collection . class ) ) ;
clientRegistrationsMap . values ( ) . stream ( )
. filter ( col - > Collection . class . isAssignableFrom ( col . getClass ( ) ) )
. filter ( col - > ( ( Collection ) col ) . stream ( )
. anyMatch ( e - > ClientRegistration . class . isAssignableFrom ( e . getClass ( ) ) ) )
. flatMap ( col - > ( ( Collection ) col ) . stream ( ) )
. forEach ( e - > clientRegistrationsList . add ( ( ClientRegistration ) e ) ) ;
if ( ! clientRegistrationsList . isEmpty ( ) ) {
return clientRegistrationsList ;
}
// Check context for type -> ClientRegistration[]
clientRegistrationsType = ResolvableType . forClass ( ClientRegistration [ ] . class ) ;
clientRegistrationsMap = BeanFactoryUtils . beansOfTypeIncludingAncestors (
http . getSharedObject ( ApplicationContext . class ) ,
clientRegistrationsType . resolve ( ClientRegistration [ ] . class ) ) ;
clientRegistrationsMap . values ( ) . stream ( )
. flatMap ( array - > Arrays . stream ( ( ClientRegistration [ ] ) array ) )
. forEach ( clientRegistrationsList : : add ) ;
return clientRegistrationsList ;
}
private void initDefaultLoginFilter ( H http ) {
DefaultLoginPageGeneratingFilter loginPageGeneratingFilter = http . getSharedObject ( DefaultLoginPageGeneratingFilter . class ) ;
if ( loginPageGeneratingFilter = = null | | this . authorizationCodeAuthenticationFilterConfigurer . isCustomLoginPage ( ) ) {
return ;
}
List < ClientRegistration > clientRegistrations = getClientRegistrations ( http ) ;
if ( CollectionUtils . isEmpty ( clientRegistrations ) ) {
Iterable < ClientRegistration > clientRegistrations = null ;
ClientRegistrationRepository clientRegistrationRepository = getClientRegistrationRepository ( http ) ;
ResolvableType type = ResolvableType . forInstance ( clientRegistrationRepository ) . as ( Iterable . class ) ;
if ( type ! = ResolvableType . NONE ) {
if ( Stream . of ( type . resolveGenerics ( ) ) . anyMatch ( ClientRegistration . class : : isAssignableFrom ) ) {
clientRegistrations = ( Iterable < ClientRegistration > ) clientRegistrationRepository ;
}
}
if ( clientRegistrations = = null ) {
return ;
}
@ -298,10 +261,9 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
@@ -298,10 +261,9 @@ public final class OAuth2LoginConfigurer<H extends HttpSecurityBuilder<H>> exten
authorizationRequestBaseUri = AuthorizationCodeRequestRedirectFilter . DEFAULT_AUTHORIZATION_REQUEST_BASE_URI ;
}
Map < String , String > oauth2AuthenticationUrlToClientName = clientRegistrations . stream ( )
. collect ( Collectors . toMap (
e - > authorizationRequestBaseUri + "/" + e . getRegistrationId ( ) ,
e - > e . getClientName ( ) ) ) ;
Map < String , String > oauth2AuthenticationUrlToClientName = new HashMap < > ( ) ;
clientRegistrations . forEach ( registration - > oauth2AuthenticationUrlToClientName . put (
authorizationRequestBaseUri + "/" + registration . getRegistrationId ( ) , registration . getClientName ( ) ) ) ;
loginPageGeneratingFilter . setOauth2LoginEnabled ( true ) ;
loginPageGeneratingFilter . setOauth2AuthenticationUrlToClientName ( oauth2AuthenticationUrlToClientName ) ;
loginPageGeneratingFilter . setLoginPageUrl ( this . authorizationCodeAuthenticationFilterConfigurer . getLoginUrl ( ) ) ;