@ -16,14 +16,20 @@
@@ -16,14 +16,20 @@
package org.springframework.security.oauth2.client.registration ;
import java.lang.reflect.Field ;
import java.lang.reflect.Modifier ;
import java.util.Arrays ;
import java.util.Collections ;
import java.util.LinkedHashMap ;
import java.util.List ;
import java.util.Map ;
import java.util.Set ;
import java.util.stream.Collectors ;
import java.util.stream.Stream ;
import org.junit.jupiter.api.Test ;
import org.junit.jupiter.params.ParameterizedTest ;
import org.junit.jupiter.params.provider.MethodSource ;
import org.springframework.security.oauth2.core.AuthenticationMethod ;
import org.springframework.security.oauth2.core.AuthorizationGrantType ;
@ -31,6 +37,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
@@ -31,6 +37,7 @@ import org.springframework.security.oauth2.core.ClientAuthenticationMethod;
import static org.assertj.core.api.Assertions.assertThat ;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException ;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException ;
/ * *
* Tests for { @link ClientRegistration } .
@ -776,4 +783,59 @@ public class ClientRegistrationTests {
@@ -776,4 +783,59 @@ public class ClientRegistrationTests {
assertThat ( clientRegistration . getClientSettings ( ) . isRequireProofKey ( ) ) . isFalse ( ) ;
}
// gh-16382
@Test
void buildWhenNewAuthorizationCodeAndPkceThenBuilds ( ) {
ClientSettings pkceEnabled = ClientSettings . builder ( ) . requireProofKey ( true ) . build ( ) ;
ClientRegistration clientRegistration = ClientRegistration . withRegistrationId ( REGISTRATION_ID )
. clientId ( CLIENT_ID )
. clientSettings ( pkceEnabled )
. authorizationGrantType ( new AuthorizationGrantType ( AuthorizationGrantType . AUTHORIZATION_CODE . getValue ( ) ) )
. redirectUri ( REDIRECT_URI )
. authorizationUri ( AUTHORIZATION_URI )
. tokenUri ( TOKEN_URI )
. build ( ) ;
// proof key should be false for passivity
assertThat ( clientRegistration . getClientSettings ( ) . isRequireProofKey ( ) ) . isTrue ( ) ;
}
@ParameterizedTest
@MethodSource ( "invalidPkceGrantTypes" )
void buildWhenInvalidGrantTypeForPkceThenException ( AuthorizationGrantType invalidGrantType ) {
ClientSettings pkceEnabled = ClientSettings . builder ( ) . requireProofKey ( true ) . build ( ) ;
ClientRegistration . Builder builder = ClientRegistration . withRegistrationId ( REGISTRATION_ID )
. clientId ( CLIENT_ID )
. clientSettings ( pkceEnabled )
. authorizationGrantType ( invalidGrantType )
. redirectUri ( REDIRECT_URI )
. authorizationUri ( AUTHORIZATION_URI )
. tokenUri ( TOKEN_URI ) ;
assertThatIllegalStateException ( ) . describedAs (
"clientSettings.isRequireProofKey=true is only valid with authorizationGrantType=AUTHORIZATION_CODE. Got authorizationGrantType={}" ,
invalidGrantType )
. isThrownBy ( builder : : build ) ;
}
static List < AuthorizationGrantType > invalidPkceGrantTypes ( ) {
return Arrays . stream ( AuthorizationGrantType . class . getFields ( ) )
. filter ( ( field ) - > Modifier . isFinal ( field . getModifiers ( ) )
& & field . getType ( ) = = AuthorizationGrantType . class )
. map ( ( field ) - > getStaticValue ( field , AuthorizationGrantType . class ) )
. filter ( ( grantType ) - > grantType ! = AuthorizationGrantType . AUTHORIZATION_CODE )
// ensure works with .equals
. map ( ( grantType ) - > new AuthorizationGrantType ( grantType . getValue ( ) ) )
. collect ( Collectors . toList ( ) ) ;
}
private static < T > T getStaticValue ( Field field , Class < T > clazz ) {
try {
return ( T ) field . get ( null ) ;
}
catch ( IllegalAccessException ex ) {
throw new RuntimeException ( ex ) ;
}
}
}