diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 188b9d761a..15d05aa732 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -714,11 +714,23 @@ public class ServerHttpSecurity { * Configures JWT Resource Server Support */ public class JwtSpec { + private ReactiveAuthenticationManager authenticationManager; private ReactiveJwtDecoder jwtDecoder; private BearerTokenServerWebExchangeMatcher bearerTokenServerWebExchangeMatcher = new BearerTokenServerWebExchangeMatcher(); + /** + * Configures the {@link ReactiveAuthenticationManager} to use + * @param authenticationManager the authentication manager to use + * @return the {@code JwtSpec} for additional configuration + */ + public JwtSpec authenticationManager(ReactiveAuthenticationManager authenticationManager) { + Assert.notNull(authenticationManager, "authenticationManager cannot be null"); + this.authenticationManager = authenticationManager; + return this; + } + /** * Configures the {@link ReactiveJwtDecoder} to use * @param jwtDecoder the decoder to use @@ -764,9 +776,7 @@ public class ServerHttpSecurity { registerDefaultAuthenticationEntryPoint(http); registerDefaultCsrfOverride(http); - ReactiveJwtDecoder jwtDecoder = getJwtDecoder(); - JwtReactiveAuthenticationManager authenticationManager = new JwtReactiveAuthenticationManager( - jwtDecoder); + ReactiveAuthenticationManager authenticationManager = getAuthenticationManager(); AuthenticationWebFilter oauth2 = new AuthenticationWebFilter(authenticationManager); oauth2.setServerAuthenticationConverter(bearerTokenConverter); oauth2.setAuthenticationFailureHandler(new ServerAuthenticationEntryPointFailureHandler(entryPoint)); @@ -782,6 +792,17 @@ public class ServerHttpSecurity { return this.jwtDecoder; } + private ReactiveAuthenticationManager getAuthenticationManager() { + if (this.authenticationManager != null) { + return this.authenticationManager; + } + + ReactiveJwtDecoder jwtDecoder = getJwtDecoder(); + ReactiveAuthenticationManager authenticationManager = + new JwtReactiveAuthenticationManager(jwtDecoder); + return authenticationManager; + } + private void registerDefaultAccessDeniedHandler(ServerHttpSecurity http) { if ( http.exceptionHandling != null ) { http.defaultAccessDeniedHandlers.add( @@ -794,7 +815,7 @@ public class ServerHttpSecurity { } private void registerDefaultAuthenticationEntryPoint(ServerHttpSecurity http) { - if ( http.exceptionHandling != null ) { + if (http.exceptionHandling != null) { http.defaultEntryPoints.add( new DelegateEntry( this.bearerTokenServerWebExchangeMatcher, diff --git a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java index b3e99f2981..534c4bd866 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/OAuth2ResourceServerSpecTests.java @@ -39,8 +39,12 @@ import org.springframework.beans.factory.NoUniqueBeanDefinitionException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; +import org.springframework.security.authentication.ReactiveAuthenticationManager; import org.springframework.security.config.annotation.web.reactive.EnableWebFluxSecurity; import org.springframework.security.config.test.SpringTestRule; +import org.springframework.security.core.Authentication; +import org.springframework.security.oauth2.core.OAuth2AuthenticationException; +import org.springframework.security.oauth2.core.OAuth2Error; import org.springframework.security.oauth2.jose.jws.JwsAlgorithms; import org.springframework.security.oauth2.jwt.Jwt; import org.springframework.security.oauth2.jwt.ReactiveJwtDecoder; @@ -57,6 +61,7 @@ import org.springframework.web.reactive.config.EnableWebFlux; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatCode; import static org.hamcrest.core.StringStartsWith.startsWith; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -161,6 +166,23 @@ public class OAuth2ResourceServerSpecTests { .expectStatus().isOk(); } + + @Test + public void getWhenUsingCustomAuthenticationManagerThenUsesItAccordingly() { + this.spring.register(CustomAuthenticationManagerConfig.class).autowire(); + + ReactiveAuthenticationManager authenticationManager = this.spring.getContext().getBean( + ReactiveAuthenticationManager.class); + when(authenticationManager.authenticate(any(Authentication.class))) + .thenReturn(Mono.error(new OAuth2AuthenticationException(new OAuth2Error("mock-failure")))); + + this.client.get() + .headers(headers -> headers.setBearerAuth(this.messageReadToken)) + .exchange() + .expectStatus().isUnauthorized() + .expectHeader().value(HttpHeaders.WWW_AUTHENTICATE, startsWith("Bearer error=\"mock-failure\"")); + } + @Test public void postWhenSignedThenReturnsOk() { this.spring.register(PublicKeyConfig.class, RootController.class).autowire(); @@ -343,6 +365,27 @@ public class OAuth2ResourceServerSpecTests { } } + @EnableWebFlux + @EnableWebFluxSecurity + static class CustomAuthenticationManagerConfig { + @Bean + SecurityWebFilterChain springSecurity(ServerHttpSecurity http) { + // @formatter:off + http + .oauth2ResourceServer() + .jwt() + .authenticationManager(authenticationManager()); + // @formatter:on + + return http.build(); + } + + @Bean + ReactiveAuthenticationManager authenticationManager() { + return mock(ReactiveAuthenticationManager.class); + } + } + @RestController static class RootController { @GetMapping @@ -356,7 +399,6 @@ public class OAuth2ResourceServerSpecTests { } } - private static RSAPublicKey publicKey() throws NoSuchAlgorithmException, InvalidKeySpecException { String modulus = "26323220897278656456354815752829448539647589990395639665273015355787577386000316054335559633864476469390247312823732994485311378484154955583861993455004584140858982659817218753831620205191028763754231454775026027780771426040997832758235764611119743390612035457533732596799927628476322029280486807310749948064176545712270582940917249337311592011920620009965129181413510845780806191965771671528886508636605814099711121026468495328702234901200169245493126030184941412539949521815665744267183140084667383643755535107759061065656273783542590997725982989978433493861515415520051342321336460543070448417126615154138673620797"; String exponent = "65537";