|
|
|
|
@ -20,10 +20,12 @@ import static org.assertj.core.api.Assertions.assertThat;
@@ -20,10 +20,12 @@ import static org.assertj.core.api.Assertions.assertThat;
|
|
|
|
|
import static org.mockito.BDDMockito.given; |
|
|
|
|
import static org.mockito.ArgumentMatchers.any; |
|
|
|
|
import static org.mockito.Mockito.mock; |
|
|
|
|
import static org.mockito.Mockito.spy; |
|
|
|
|
import static org.mockito.Mockito.verify; |
|
|
|
|
import static org.mockito.Mockito.verifyZeroInteractions; |
|
|
|
|
import static org.mockito.Mockito.when; |
|
|
|
|
import static org.springframework.security.config.Customizer.withDefaults; |
|
|
|
|
import static org.springframework.test.util.ReflectionTestUtils.getField; |
|
|
|
|
|
|
|
|
|
import java.util.Arrays; |
|
|
|
|
import java.util.List; |
|
|
|
|
@ -35,16 +37,20 @@ import org.apache.http.HttpHeaders;
@@ -35,16 +37,20 @@ import org.apache.http.HttpHeaders;
|
|
|
|
|
import org.junit.Before; |
|
|
|
|
import org.junit.Test; |
|
|
|
|
import org.junit.runner.RunWith; |
|
|
|
|
import org.mockito.ArgumentCaptor; |
|
|
|
|
import org.mockito.Mock; |
|
|
|
|
import org.mockito.junit.MockitoJUnitRunner; |
|
|
|
|
|
|
|
|
|
import org.springframework.security.core.Authentication; |
|
|
|
|
import org.springframework.security.oauth2.client.registration.ReactiveClientRegistrationRepository; |
|
|
|
|
import org.springframework.security.oauth2.client.web.server.ServerAuthorizationRequestRepository; |
|
|
|
|
import org.springframework.security.oauth2.client.web.server.authentication.OAuth2LoginAuthenticationWebFilter; |
|
|
|
|
import org.springframework.security.oauth2.core.endpoint.OAuth2AuthorizationRequest; |
|
|
|
|
import org.springframework.security.oauth2.core.endpoint.TestOAuth2AuthorizationRequests; |
|
|
|
|
import org.springframework.security.web.authentication.preauth.x509.X509PrincipalExtractor; |
|
|
|
|
import org.springframework.security.web.server.authentication.ServerX509AuthenticationConverter; |
|
|
|
|
import org.springframework.security.web.server.savedrequest.ServerRequestCache; |
|
|
|
|
import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache; |
|
|
|
|
import reactor.core.publisher.Mono; |
|
|
|
|
import reactor.test.publisher.TestPublisher; |
|
|
|
|
|
|
|
|
|
@ -64,7 +70,6 @@ import org.springframework.security.web.server.context.WebSessionServerSecurityC
@@ -64,7 +70,6 @@ import org.springframework.security.web.server.context.WebSessionServerSecurityC
|
|
|
|
|
import org.springframework.security.web.server.csrf.CsrfServerLogoutHandler; |
|
|
|
|
import org.springframework.security.web.server.csrf.CsrfWebFilter; |
|
|
|
|
import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository; |
|
|
|
|
import org.springframework.test.util.ReflectionTestUtils; |
|
|
|
|
import org.springframework.test.web.reactive.server.EntityExchangeResult; |
|
|
|
|
import org.springframework.test.web.reactive.server.FluxExchangeResult; |
|
|
|
|
import org.springframework.test.web.reactive.server.WebTestClient; |
|
|
|
|
@ -200,7 +205,7 @@ public class ServerHttpSecurityTests {
@@ -200,7 +205,7 @@ public class ServerHttpSecurityTests {
|
|
|
|
|
.isNotPresent(); |
|
|
|
|
|
|
|
|
|
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class) |
|
|
|
|
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); |
|
|
|
|
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); |
|
|
|
|
|
|
|
|
|
assertThat(logoutHandler) |
|
|
|
|
.get() |
|
|
|
|
@ -213,17 +218,17 @@ public class ServerHttpSecurityTests {
@@ -213,17 +218,17 @@ public class ServerHttpSecurityTests {
|
|
|
|
|
|
|
|
|
|
assertThat(getWebFilter(securityWebFilterChain, CsrfWebFilter.class)) |
|
|
|
|
.get() |
|
|
|
|
.extracting(csrfWebFilter -> ReflectionTestUtils.getField(csrfWebFilter, "csrfTokenRepository")) |
|
|
|
|
.extracting(csrfWebFilter -> getField(csrfWebFilter, "csrfTokenRepository")) |
|
|
|
|
.isEqualTo(this.csrfTokenRepository); |
|
|
|
|
|
|
|
|
|
Optional<ServerLogoutHandler> logoutHandler = getWebFilter(securityWebFilterChain, LogoutWebFilter.class) |
|
|
|
|
.map(logoutWebFilter -> (ServerLogoutHandler) ReflectionTestUtils.getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); |
|
|
|
|
.map(logoutWebFilter -> (ServerLogoutHandler) getField(logoutWebFilter, LogoutWebFilter.class, "logoutHandler")); |
|
|
|
|
|
|
|
|
|
assertThat(logoutHandler) |
|
|
|
|
.get() |
|
|
|
|
.isExactlyInstanceOf(DelegatingServerLogoutHandler.class) |
|
|
|
|
.extracting(delegatingLogoutHandler -> |
|
|
|
|
((List<ServerLogoutHandler>) ReflectionTestUtils.getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream() |
|
|
|
|
((List<ServerLogoutHandler>) getField(delegatingLogoutHandler, DelegatingServerLogoutHandler.class, "delegates")).stream() |
|
|
|
|
.map(ServerLogoutHandler::getClass) |
|
|
|
|
.collect(Collectors.toList())) |
|
|
|
|
.isEqualTo(Arrays.asList(SecurityContextServerLogoutHandler.class, CsrfServerLogoutHandler.class)); |
|
|
|
|
@ -479,6 +484,33 @@ public class ServerHttpSecurityTests {
@@ -479,6 +484,33 @@ public class ServerHttpSecurityTests {
|
|
|
|
|
verify(customServerCsrfTokenRepository).loadToken(any()); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
|
public void shouldConfigureRequestCacheForOAuth2LoginAuthenticationEntryPointAndSuccessHandler() { |
|
|
|
|
ServerRequestCache requestCache = spy(new WebSessionServerRequestCache()); |
|
|
|
|
ReactiveClientRegistrationRepository clientRegistrationRepository = mock(ReactiveClientRegistrationRepository.class); |
|
|
|
|
|
|
|
|
|
SecurityWebFilterChain securityFilterChain = this.http |
|
|
|
|
.oauth2Login() |
|
|
|
|
.clientRegistrationRepository(clientRegistrationRepository) |
|
|
|
|
.and() |
|
|
|
|
.authorizeExchange().anyExchange().authenticated() |
|
|
|
|
.and() |
|
|
|
|
.requestCache(c -> c.requestCache(requestCache)) |
|
|
|
|
.build(); |
|
|
|
|
|
|
|
|
|
WebTestClient client = WebTestClientBuilder.bindToWebFilters(securityFilterChain).build(); |
|
|
|
|
client.get().uri("/test").exchange(); |
|
|
|
|
ArgumentCaptor<ServerWebExchange> captor = ArgumentCaptor.forClass(ServerWebExchange.class); |
|
|
|
|
verify(requestCache).saveRequest(captor.capture()); |
|
|
|
|
assertThat(captor.getValue().getRequest().getURI().toString()).isEqualTo("/test"); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
OAuth2LoginAuthenticationWebFilter authenticationWebFilter = |
|
|
|
|
getWebFilter(securityFilterChain, OAuth2LoginAuthenticationWebFilter.class).get(); |
|
|
|
|
Object handler = getField(authenticationWebFilter, "authenticationSuccessHandler"); |
|
|
|
|
assertThat(getField(handler, "requestCache")).isSameAs(requestCache); |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@Test |
|
|
|
|
public void shouldConfigureAuthorizationRequestRepositoryForOAuth2Login() { |
|
|
|
|
ServerAuthorizationRequestRepository<OAuth2AuthorizationRequest> authorizationRequestRepository = mock(ServerAuthorizationRequestRepository.class); |
|
|
|
|
@ -503,7 +535,7 @@ public class ServerHttpSecurityTests {
@@ -503,7 +535,7 @@ public class ServerHttpSecurityTests {
|
|
|
|
|
|
|
|
|
|
private boolean isX509Filter(WebFilter filter) { |
|
|
|
|
try { |
|
|
|
|
Object converter = ReflectionTestUtils.getField(filter, "authenticationConverter"); |
|
|
|
|
Object converter = getField(filter, "authenticationConverter"); |
|
|
|
|
return converter.getClass().isAssignableFrom(ServerX509AuthenticationConverter.class); |
|
|
|
|
} catch (IllegalArgumentException e) { |
|
|
|
|
// field doesn't exist
|
|
|
|
|
|