diff --git a/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java b/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java index ce4e349a97..3d2e6be3ac 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java +++ b/config/src/main/java/org/springframework/security/config/web/server/SecurityWebFiltersOrder.java @@ -46,6 +46,10 @@ public enum SecurityWebFiltersOrder { * {@link org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter} */ SECURITY_CONTEXT_SERVER_WEB_EXCHANGE, + /** + * {@link org.springframework.security.web.server.savedrequest.ServerRequestCacheWebFilter} + */ + SERVER_REQUEST_CACHE, LOGOUT, EXCEPTION_TRANSLATION, AUTHORIZATION, diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index fea0a0d36c..55a7515a98 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -39,7 +39,6 @@ import org.springframework.security.web.server.authentication.ServerAuthenticati import org.springframework.security.web.server.authentication.ServerAuthenticationFailureHandler; import org.springframework.security.web.server.authentication.ServerAuthenticationSuccessHandler; import org.springframework.security.web.server.authentication.logout.LogoutWebFilter; -import org.springframework.security.web.server.authentication.logout.SecurityContextServerLogoutHandler; import org.springframework.security.web.server.authentication.logout.ServerLogoutHandler; import org.springframework.security.web.server.authentication.logout.ServerLogoutSuccessHandler; import org.springframework.security.web.server.authorization.AuthorizationContext; @@ -47,10 +46,10 @@ import org.springframework.security.web.server.authorization.AuthorizationWebFil import org.springframework.security.web.server.authorization.DelegatingReactiveAuthorizationManager; import org.springframework.security.web.server.authorization.ExceptionTranslationWebFilter; import org.springframework.security.web.server.authorization.ServerAccessDeniedHandler; -import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; +import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository; import org.springframework.security.web.server.context.ReactorContextWebFilter; +import org.springframework.security.web.server.context.SecurityContextServerWebExchangeWebFilter; import org.springframework.security.web.server.context.ServerSecurityContextRepository; -import org.springframework.security.web.server.context.NoOpServerSecurityContextRepository; import org.springframework.security.web.server.context.WebSessionServerSecurityContextRepository; import org.springframework.security.web.server.csrf.CsrfWebFilter; import org.springframework.security.web.server.csrf.ServerCsrfTokenRepository; @@ -62,6 +61,10 @@ import org.springframework.security.web.server.header.ServerHttpHeadersWriter; import org.springframework.security.web.server.header.StrictTransportSecurityServerHttpHeadersWriter; import org.springframework.security.web.server.header.XFrameOptionsServerHttpHeadersWriter; import org.springframework.security.web.server.header.XXssProtectionServerHttpHeadersWriter; +import org.springframework.security.web.server.savedrequest.NoOpServerRequestCache; +import org.springframework.security.web.server.savedrequest.ServerRequestCache; +import org.springframework.security.web.server.savedrequest.ServerRequestCacheWebFilter; +import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache; import org.springframework.security.web.server.ui.LoginPageGeneratingWebFilter; import org.springframework.security.web.server.ui.LogoutPageGeneratingWebFilter; import org.springframework.security.web.server.util.matcher.MediaTypeServerWebExchangeMatcher; @@ -102,6 +105,8 @@ public class ServerHttpSecurity { private HttpBasicBuilder httpBasic; + private final RequestCacheBuilder requestCache = new RequestCacheBuilder(); + private FormLoginBuilder formLogin; private LogoutBuilder logout = new LogoutBuilder(); @@ -198,6 +203,10 @@ public class ServerHttpSecurity { return this.logout; } + public RequestCacheBuilder requestCache() { + return this.requestCache; + } + public ServerHttpSecurity authenticationManager(ReactiveAuthenticationManager manager) { this.authenticationManager = manager; return this; @@ -239,6 +248,7 @@ public class ServerHttpSecurity { if(this.logout != null) { this.logout.configure(this); } + this.requestCache.configure(this); this.addFilterAt(new SecurityContextServerWebExchangeWebFilter(), SecurityWebFiltersOrder.SECURITY_CONTEXT_SERVER_WEB_EXCHANGE); if(this.authorizeExchangeBuilder != null) { ServerAuthenticationEntryPoint serverAuthenticationEntryPoint = getServerAuthenticationEntryPoint(); @@ -433,6 +443,35 @@ public class ServerHttpSecurity { private ExceptionHandlingBuilder() {} } + /** + * @author Rob Winch + * @since 5.0 + */ + public class RequestCacheBuilder { + private ServerRequestCache requestCache = new WebSessionServerRequestCache(); + + public RequestCacheBuilder requestCache(ServerRequestCache requestCache) { + Assert.notNull(requestCache, "requestCache cannot be null"); + this.requestCache = requestCache; + return this; + } + + protected void configure(ServerHttpSecurity http) { + http.addFilterAt(new ServerRequestCacheWebFilter(), SecurityWebFiltersOrder.SERVER_REQUEST_CACHE); + } + + public ServerHttpSecurity and() { + return ServerHttpSecurity.this; + } + + public ServerHttpSecurity disable() { + this.requestCache = NoOpServerRequestCache.getInstance(); + return and(); + } + + private RequestCacheBuilder() {} + } + /** * @author Rob Winch * @since 5.0 @@ -489,6 +528,10 @@ public class ServerHttpSecurity { * @since 5.0 */ public class FormLoginBuilder { + private final RedirectServerAuthenticationSuccessHandler defaultSuccessHandler = new RedirectServerAuthenticationSuccessHandler("/"); + + private RedirectServerAuthenticationEntryPoint defaultEntryPoint; + private ReactiveAuthenticationManager authenticationManager; private ServerSecurityContextRepository serverSecurityContextRepository = new WebSessionServerSecurityContextRepository(); @@ -499,7 +542,7 @@ public class ServerHttpSecurity { private ServerAuthenticationFailureHandler serverAuthenticationFailureHandler; - private ServerAuthenticationSuccessHandler serverAuthenticationSuccessHandler = new RedirectServerAuthenticationSuccessHandler("/"); + private ServerAuthenticationSuccessHandler serverAuthenticationSuccessHandler = this.defaultSuccessHandler; public FormLoginBuilder authenticationManager(ReactiveAuthenticationManager authenticationManager) { this.authenticationManager = authenticationManager; @@ -514,7 +557,8 @@ public class ServerHttpSecurity { } public FormLoginBuilder loginPage(String loginPage) { - this.serverAuthenticationEntryPoint = new RedirectServerAuthenticationEntryPoint(loginPage); + this.defaultEntryPoint = new RedirectServerAuthenticationEntryPoint(loginPage); + this.serverAuthenticationEntryPoint = this.defaultEntryPoint; this.requiresAuthenticationMatcher = ServerWebExchangeMatchers.pathMatchers(HttpMethod.POST, loginPage); this.serverAuthenticationFailureHandler = new RedirectServerAuthenticationFailureHandler(loginPage + "?error"); return this; @@ -553,6 +597,13 @@ public class ServerHttpSecurity { if(this.serverAuthenticationEntryPoint == null) { loginPage("/login"); } + if(http.requestCache != null) { + ServerRequestCache requestCache = http.requestCache.requestCache; + this.defaultSuccessHandler.setRequestCache(requestCache); + if(this.defaultEntryPoint != null) { + this.defaultEntryPoint.setRequestCache(requestCache); + } + } MediaTypeServerWebExchangeMatcher htmlMatcher = new MediaTypeServerWebExchangeMatcher( MediaType.TEXT_HTML); htmlMatcher.setIgnoredMediaTypes(Collections.singleton(MediaType.ALL)); diff --git a/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java b/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java index 1b8b8f6c7d..921bdaf921 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/FormLoginTests.java @@ -21,15 +21,9 @@ import org.openqa.selenium.WebDriver; import org.openqa.selenium.WebElement; import org.openqa.selenium.support.FindBy; import org.openqa.selenium.support.PageFactory; -import org.springframework.security.authentication.ReactiveAuthenticationManager; -import org.springframework.security.authentication.UserDetailsRepositoryReactiveAuthenticationManager; import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; -import org.springframework.security.core.userdetails.MapReactiveUserDetailsService; -import org.springframework.security.core.userdetails.User; -import org.springframework.security.core.userdetails.UserDetails; import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder; import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; -import org.springframework.security.web.context.SaveContextOnUpdateOrErrorResponseWrapperTests; import org.springframework.security.web.server.SecurityWebFilterChain; import org.springframework.security.web.server.WebFilterChainProxy; import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler; @@ -39,7 +33,6 @@ import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.ResponseBody; import org.springframework.web.server.ServerWebExchange; -import reactor.core.publisher.Mono; import static org.assertj.core.api.Assertions.assertThat; @@ -143,7 +136,7 @@ public class FormLoginTests { .webTestClientSetup(webTestClient) .build(); - DefaultLoginPage loginPage = HomePage.to(driver, DefaultLoginPage.class) + DefaultLoginPage loginPage = DefaultLoginPage.to(driver) .assertAt(); HomePage homePage = loginPage.loginForm() @@ -238,6 +231,11 @@ public class FormLoginTests { return this.loginForm; } + static DefaultLoginPage to(WebDriver driver) { + driver.get("http://localhost/login"); + return PageFactory.initElements(driver, DefaultLoginPage.class); + } + public static class LoginForm { private WebDriver driver; private WebElement username; @@ -347,6 +345,5 @@ public class FormLoginTests { + " \n" + ""; } - } } diff --git a/config/src/test/java/org/springframework/security/config/web/server/RequestCacheTests.java b/config/src/test/java/org/springframework/security/config/web/server/RequestCacheTests.java new file mode 100644 index 0000000000..bf25ed25b2 --- /dev/null +++ b/config/src/test/java/org/springframework/security/config/web/server/RequestCacheTests.java @@ -0,0 +1,147 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.config.web.server; + +import org.junit.Test; +import org.openqa.selenium.WebDriver; +import org.openqa.selenium.WebElement; +import org.openqa.selenium.support.FindBy; +import org.openqa.selenium.support.PageFactory; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.config.annotation.web.reactive.ServerHttpSecurityConfigurationBuilder; +import org.springframework.security.config.web.server.FormLoginTests.DefaultLoginPage; +import org.springframework.security.config.web.server.FormLoginTests.HomePage; +import org.springframework.security.htmlunit.server.WebTestClientHtmlUnitDriverBuilder; +import org.springframework.security.test.web.reactive.server.WebTestClientBuilder; +import org.springframework.security.web.server.SecurityWebFilterChain; +import org.springframework.security.web.server.WebFilterChainProxy; +import org.springframework.security.web.server.authentication.RedirectServerAuthenticationSuccessHandler; +import org.springframework.security.web.server.csrf.CsrfToken; +import org.springframework.security.web.server.savedrequest.NoOpServerRequestCache; +import org.springframework.stereotype.Controller; +import org.springframework.test.web.reactive.server.WebTestClient; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.server.ServerWebExchange; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Rob Winch + * @since 5.0 + */ +public class RequestCacheTests { + private ServerHttpSecurity http = ServerHttpSecurityConfigurationBuilder.httpWithDefaultAuthentication(); + + @Test + public void defaultFormLoginRequestCache() { + SecurityWebFilterChain securityWebFilter = this.http + .authorizeExchange() + .anyExchange().authenticated() + .and() + .formLogin().and() + .build(); + + WebTestClient webTestClient = WebTestClient + .bindToController(new SecuredPageController(), new WebTestClientBuilder.Http200RestController()) + .webFilter(new WebFilterChainProxy(securityWebFilter)) + .build(); + + WebDriver driver = WebTestClientHtmlUnitDriverBuilder + .webTestClientSetup(webTestClient) + .build(); + + DefaultLoginPage loginPage = SecuredPage.to(driver, DefaultLoginPage.class) + .assertAt(); + + SecuredPage securedPage = loginPage.loginForm() + .username("user") + .password("password") + .submit(SecuredPage.class); + + securedPage.assertAt(); + } + + @Test + public void requestCacheNoOp() { + SecurityWebFilterChain securityWebFilter = this.http + .authorizeExchange() + .anyExchange().authenticated() + .and() + .formLogin().and() + .requestCache() + .requestCache(NoOpServerRequestCache.getInstance()) + .and() + .build(); + + WebTestClient webTestClient = WebTestClient + .bindToController(new SecuredPageController(), new WebTestClientBuilder.Http200RestController()) + .webFilter(new WebFilterChainProxy(securityWebFilter)) + .build(); + + WebDriver driver = WebTestClientHtmlUnitDriverBuilder + .webTestClientSetup(webTestClient) + .build(); + + DefaultLoginPage loginPage = SecuredPage.to(driver, DefaultLoginPage.class) + .assertAt(); + + HomePage securedPage = loginPage.loginForm() + .username("user") + .password("password") + .submit(HomePage.class); + + securedPage.assertAt(); + } + + public static class SecuredPage { + private WebDriver driver; + + public SecuredPage(WebDriver driver) { + this.driver = driver; + } + + public void assertAt() { + assertThat(this.driver.getTitle()).isEqualTo("Secured"); + } + + static T to(WebDriver driver, Class page) { + driver.get("http://localhost/secured"); + return PageFactory.initElements(driver, page); + } + } + + @Controller + public static class SecuredPageController { + @ResponseBody + @GetMapping("/secured") + public String login(ServerWebExchange exchange) { + CsrfToken token = exchange.getAttribute(CsrfToken.class.getName()); + return + "\n" + + "\n" + + " \n" + + " Secured\n" + + " \n" + + " \n" + + "

Secured

\n" + + " \n" + + ""; + } + } +} diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/RedirectServerAuthenticationEntryPoint.java b/web/src/main/java/org/springframework/security/web/server/authentication/RedirectServerAuthenticationEntryPoint.java index 6382a6d4a6..3c28400af6 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/RedirectServerAuthenticationEntryPoint.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/RedirectServerAuthenticationEntryPoint.java @@ -20,6 +20,8 @@ import java.net.URI; import org.springframework.security.web.server.DefaultServerRedirectStrategy; import org.springframework.security.web.server.ServerRedirectStrategy; +import org.springframework.security.web.server.savedrequest.ServerRequestCache; +import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache; import reactor.core.publisher.Mono; import org.springframework.security.core.AuthenticationException; @@ -39,14 +41,22 @@ public class RedirectServerAuthenticationEntryPoint private ServerRedirectStrategy serverRedirectStrategy = new DefaultServerRedirectStrategy(); + private ServerRequestCache requestCache = new WebSessionServerRequestCache(); + public RedirectServerAuthenticationEntryPoint(String location) { Assert.notNull(location, "location cannot be null"); this.location = URI.create(location); } + public void setRequestCache(ServerRequestCache requestCache) { + Assert.notNull(requestCache, "requestCache cannot be null"); + this.requestCache = requestCache; + } + @Override public Mono commence(ServerWebExchange exchange, AuthenticationException e) { - return this.serverRedirectStrategy.sendRedirect(exchange, this.location); + return this.requestCache.saveRequest(exchange) + .then(this.serverRedirectStrategy.sendRedirect(exchange, this.location)); } /** diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/RedirectServerAuthenticationSuccessHandler.java b/web/src/main/java/org/springframework/security/web/server/authentication/RedirectServerAuthenticationSuccessHandler.java index 3d29c8ce87..75d1c483b6 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/RedirectServerAuthenticationSuccessHandler.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/RedirectServerAuthenticationSuccessHandler.java @@ -20,6 +20,8 @@ import org.springframework.security.core.Authentication; import org.springframework.security.web.server.DefaultServerRedirectStrategy; import org.springframework.security.web.server.ServerRedirectStrategy; import org.springframework.security.web.server.WebFilterExchange; +import org.springframework.security.web.server.savedrequest.ServerRequestCache; +import org.springframework.security.web.server.savedrequest.WebSessionServerRequestCache; import org.springframework.util.Assert; import org.springframework.web.server.ServerWebExchange; import reactor.core.publisher.Mono; @@ -36,17 +38,28 @@ public class RedirectServerAuthenticationSuccessHandler private ServerRedirectStrategy serverRedirectStrategy = new DefaultServerRedirectStrategy(); + private ServerRequestCache requestCache = new WebSessionServerRequestCache(); + public RedirectServerAuthenticationSuccessHandler() {} public RedirectServerAuthenticationSuccessHandler(String location) { this.location = URI.create(location); } + public void setRequestCache(ServerRequestCache requestCache) { + Assert.notNull(requestCache, "requestCache cannot be null"); + this.requestCache = requestCache; + } + @Override public Mono onAuthenticationSuccess(WebFilterExchange webFilterExchange, Authentication authentication) { ServerWebExchange exchange = webFilterExchange.getExchange(); - return this.serverRedirectStrategy.sendRedirect(exchange, this.location); + return this.requestCache.getRequest(exchange) + .map(r -> r.getPath().pathWithinApplication().value()) + .map(URI::create) + .defaultIfEmpty(this.location) + .flatMap(location -> this.serverRedirectStrategy.sendRedirect(exchange, location)); } /** diff --git a/web/src/main/java/org/springframework/security/web/server/savedrequest/NoOpServerRequestCache.java b/web/src/main/java/org/springframework/security/web/server/savedrequest/NoOpServerRequestCache.java new file mode 100644 index 0000000000..6872a38506 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/savedrequest/NoOpServerRequestCache.java @@ -0,0 +1,54 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.savedrequest; + +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +/** + * @author Rob Winch + * @since 5.0 + */ +public class NoOpServerRequestCache implements ServerRequestCache { + @Override + public Mono saveRequest(ServerWebExchange exchange) { + return Mono.empty(); + } + + @Override + public Mono getRequest(ServerWebExchange exchange) { + return Mono.empty(); + } + + @Override + public Mono getMatchingRequest( + ServerWebExchange exchange) { + return Mono.empty(); + } + + @Override + public Mono removeRequest(ServerWebExchange exchange) { + return Mono.empty(); + } + + public static NoOpServerRequestCache getInstance() { + return new NoOpServerRequestCache(); + } + + private NoOpServerRequestCache() {} +} diff --git a/web/src/main/java/org/springframework/security/web/server/savedrequest/ServerRequestCache.java b/web/src/main/java/org/springframework/security/web/server/savedrequest/ServerRequestCache.java new file mode 100644 index 0000000000..a4b256d0fd --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/savedrequest/ServerRequestCache.java @@ -0,0 +1,64 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.savedrequest; + +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.web.server.ServerWebExchange; +import reactor.core.publisher.Mono; + +/** + * Saves a {@link ServerHttpRequest} so it can be "replayed" later. This is useful for + * when a page was requested and authentication is necessary. + * + * @author Rob Winch + * @since 5.0 + */ +public interface ServerRequestCache { + + /** + * Save the {@link ServerHttpRequest} + * @param exchange the exchange to save + * @return Return a {@code Mono} which only replays complete and error signals + * from this {@link Mono}. + */ + Mono saveRequest(ServerWebExchange exchange); + + /** + * Get the saved {@link ServerHttpRequest} + * @param exchange the exchange to obtain the saved {@link ServerHttpRequest} from + * @return the {@link ServerHttpRequest} + */ + Mono getRequest(ServerWebExchange exchange); + + /** + * If the provided {@link ServerWebExchange} matches the saved {@link ServerHttpRequest} + * gets the saved {@link ServerHttpRequest} + * @param exchange the exchange to obtain the request from + * @return the {@link ServerHttpRequest} + */ + Mono getMatchingRequest(ServerWebExchange exchange); + + /** + * If the {@link ServerWebExchange} contains a saved {@link ServerHttpRequest} remove + * and return it. + * + * @param exchange the {@link ServerWebExchange} to obtain and remove the + * {@link ServerHttpRequest} + * @return the {@link ServerHttpRequest} + */ + Mono removeRequest(ServerWebExchange exchange); +} diff --git a/web/src/main/java/org/springframework/security/web/server/savedrequest/ServerRequestCacheWebFilter.java b/web/src/main/java/org/springframework/security/web/server/savedrequest/ServerRequestCacheWebFilter.java new file mode 100644 index 0000000000..6140e9c803 --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/savedrequest/ServerRequestCacheWebFilter.java @@ -0,0 +1,47 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.savedrequest; + +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebFilter; +import org.springframework.web.server.WebFilterChain; +import reactor.core.publisher.Mono; + +/** + * A {@link WebFilter} that replays any matching request in {@link ServerRequestCache} + * + * @author Rob Winch + * @since 5.0 + */ +public class ServerRequestCacheWebFilter implements WebFilter { + private ServerRequestCache requestCache = new WebSessionServerRequestCache(); + + @Override + public Mono filter(ServerWebExchange exchange, WebFilterChain chain) { + return this.requestCache.getMatchingRequest(exchange) + .flatMap(r -> this.requestCache.removeRequest(exchange)) + .map(r -> exchange.mutate().request(r).build()) + .defaultIfEmpty(exchange) + .flatMap(e -> chain.filter(e)); + } + + public void setRequestCache(ServerRequestCache requestCache) { + Assert.notNull(requestCache, "requestCache cannot be null"); + this.requestCache = requestCache; + } +} diff --git a/web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java b/web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java new file mode 100644 index 0000000000..ced3dd1cdf --- /dev/null +++ b/web/src/main/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCache.java @@ -0,0 +1,99 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.savedrequest; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.springframework.http.HttpMethod; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; +import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatchers; +import org.springframework.util.Assert; +import org.springframework.web.server.ServerWebExchange; +import org.springframework.web.server.WebSession; +import reactor.core.publisher.Mono; + +import java.net.URI; + +/** + * An implementation of {@link ServerRequestCache} that saves the + * {@link ServerHttpRequest} in the {@link WebSession}. + * + * The current implementation only saves the URL that was requested. + * + * @author Rob Winch + * @since 5.0 + */ +public class WebSessionServerRequestCache implements ServerRequestCache { + private static final String DEFAULT_SAVED_REQUEST_ATTR = "SPRING_SECURITY_SAVED_REQUEST"; + + protected final Log logger = LogFactory.getLog(this.getClass()); + + private String sessionAttrName = DEFAULT_SAVED_REQUEST_ATTR; + + private ServerWebExchangeMatcher saveRequestMatcher = ServerWebExchangeMatchers.pathMatchers( + HttpMethod.GET, "/**"); + + /** + * Sets the matcher to determine if the request should be saved. The default is to match + * on any GET request. + * + * @param saveRequestMatcher + */ + public void setSaveRequestMatcher(ServerWebExchangeMatcher saveRequestMatcher) { + Assert.notNull(saveRequestMatcher, "saveRequestMatcher cannot be null"); + this.saveRequestMatcher = saveRequestMatcher; + } + + @Override + public Mono saveRequest(ServerWebExchange exchange) { + return this.saveRequestMatcher.matches(exchange) + .filter(m -> m.isMatch()) + .flatMap(m -> exchange.getSession()) + .map(WebSession::getAttributes) + .doOnNext(attrs -> attrs.put(this.sessionAttrName, pathInApplication(exchange.getRequest()))) + .then(); + } + + @Override + public Mono getRequest(ServerWebExchange exchange) { + return exchange.getSession() + .flatMap(session -> Mono.justOrEmpty(session.getAttribute(this.sessionAttrName))) + .map(path -> exchange.getRequest().mutate().path(path).build()); + } + + @Override + public Mono getMatchingRequest( + ServerWebExchange exchange) { + return getRequest(exchange) + .filter( request -> pathInApplication(request).equals( + pathInApplication(exchange.getRequest()))); + } + + @Override + public Mono removeRequest(ServerWebExchange exchange) { + return exchange.getSession() + .map(WebSession::getAttributes) + .flatMap(attrs -> Mono.justOrEmpty(attrs.remove(this.sessionAttrName))) + .cast(String.class) + .map(path -> exchange.getRequest().mutate().path(path).build()); + } + + private static String pathInApplication(ServerHttpRequest request) { + return request.getPath().pathWithinApplication().value(); + } +} diff --git a/web/src/test/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCacheTests.java b/web/src/test/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCacheTests.java new file mode 100644 index 0000000000..0ac1545725 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/server/savedrequest/WebSessionServerRequestCacheTests.java @@ -0,0 +1,82 @@ +/* + * Copyright 2002-2017 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.security.web.server.savedrequest; + +import org.junit.Test; +import org.springframework.http.server.reactive.ServerHttpRequest; +import org.springframework.mock.http.server.reactive.MockServerHttpRequest; +import org.springframework.mock.web.server.MockServerWebExchange; +import org.springframework.security.web.server.util.matcher.ServerWebExchangeMatcher; + +import static org.assertj.core.api.Assertions.*; + +/** + * @author Rob Winch + * @since 5.0 + */ +public class WebSessionServerRequestCacheTests { + private WebSessionServerRequestCache cache = new WebSessionServerRequestCache(); + + @Test + public void saveRequestGetRequestWhenGetThenFound() { + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/secured/")); + this.cache.saveRequest(exchange).block(); + + ServerHttpRequest saved = this.cache.getRequest(exchange).block(); + + assertThat(saved.getURI()).isEqualTo(exchange.getRequest().getURI()); + } + + @Test + public void saveRequestGetRequestWhenPostThenNotFound() { + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.post("/secured/")); + this.cache.saveRequest(exchange).block(); + + assertThat(this.cache.getRequest(exchange).block()).isNull(); + } + + @Test + public void saveRequestGetRequestWhenPostAndCustomMatcherThenFound() { + this.cache.setSaveRequestMatcher(e -> ServerWebExchangeMatcher.MatchResult.match()); + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.post("/secured/")); + this.cache.saveRequest(exchange).block(); + + ServerHttpRequest saved = this.cache.getRequest(exchange).block(); + + assertThat(saved.getURI()).isEqualTo(exchange.getRequest().getURI()); + } + + @Test + public void saveRequestRemoveRequestWhenThenFound() { + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/secured/")); + this.cache.saveRequest(exchange).block(); + + ServerHttpRequest saved = this.cache.removeRequest(exchange).block(); + + assertThat(saved.getURI()).isEqualTo(exchange.getRequest().getURI()); + } + + @Test + public void removeRequestGetRequestWhenDefaultThenNotFound() { + MockServerWebExchange exchange = MockServerWebExchange.from(MockServerHttpRequest.get("/secured/")); + this.cache.saveRequest(exchange).block(); + + this.cache.removeRequest(exchange).block(); + + assertThat(this.cache.getRequest(exchange).block()).isNull(); + } +}