diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java index 98ec1ebdfd7..d03c606d801 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/mvc/method/annotation/HttpEntityMethodProcessor.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,8 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import org.springframework.core.MethodParameter; import org.springframework.core.ResolvableType; @@ -34,7 +36,9 @@ import org.springframework.http.ResponseEntity; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.http.server.ServletServerHttpResponse; +import org.springframework.ui.ModelMap; import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.HttpMediaTypeNotSupportedException; import org.springframework.web.accept.ContentNegotiationManager; @@ -42,6 +46,8 @@ import org.springframework.web.bind.support.WebDataBinderFactory; import org.springframework.web.context.request.NativeWebRequest; import org.springframework.web.context.request.ServletWebRequest; import org.springframework.web.method.support.ModelAndViewContainer; +import org.springframework.web.servlet.mvc.support.RedirectAttributes; +import org.springframework.web.servlet.support.RequestContextUtils; /** * Resolves {@link HttpEntity} and {@link RequestEntity} method argument values @@ -197,6 +203,12 @@ public class HttpEntityMethodProcessor extends AbstractMessageConverterMethodPro return; } } + else if (returnStatus / 100 == 3) { + String location = outputHeaders.getFirst("location"); + if (location != null) { + saveFlashAttributes(mavContainer, webRequest, location); + } + } } // Try even with null body. ResponseBodyAdvice could get involved. @@ -241,6 +253,20 @@ public class HttpEntityMethodProcessor extends AbstractMessageConverterMethodPro return servletWebRequest.checkNotModified(etag, lastModifiedTimestamp); } + private void saveFlashAttributes(ModelAndViewContainer mav, NativeWebRequest request, String location) { + mav.setRedirectModelScenario(true); + ModelMap model = mav.getModel(); + if (model instanceof RedirectAttributes) { + Map flashAttributes = ((RedirectAttributes) model).getFlashAttributes(); + if (!CollectionUtils.isEmpty(flashAttributes)) { + HttpServletRequest req = request.getNativeRequest(HttpServletRequest.class); + HttpServletResponse res = request.getNativeRequest(HttpServletResponse.class); + RequestContextUtils.getOutputFlashMap(req).putAll(flashAttributes); + RequestContextUtils.saveOutputFlashMap(location, req, res); + } + } + } + @Override protected Class getReturnValueType(Object returnValue, MethodParameter returnType) { if (returnValue != null) { diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/support/RequestContextUtils.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/support/RequestContextUtils.java index a9ac2e84bee..c9c7590f887 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/support/RequestContextUtils.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/support/RequestContextUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2016 the original author or authors. + * Copyright 2002-2017 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,11 +22,14 @@ import java.util.TimeZone; import javax.servlet.ServletContext; import javax.servlet.ServletRequest; import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; import org.springframework.context.i18n.LocaleContext; import org.springframework.context.i18n.TimeZoneAwareLocaleContext; import org.springframework.ui.context.Theme; import org.springframework.ui.context.ThemeSource; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; import org.springframework.web.context.ContextLoader; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.support.WebApplicationContextUtils; @@ -36,6 +39,8 @@ import org.springframework.web.servlet.FlashMapManager; import org.springframework.web.servlet.LocaleContextResolver; import org.springframework.web.servlet.LocaleResolver; import org.springframework.web.servlet.ThemeResolver; +import org.springframework.web.util.UriComponents; +import org.springframework.web.util.UriComponentsBuilder; /** * Utility class for easy access to request-specific state which has been @@ -209,9 +214,8 @@ public abstract class RequestContextUtils { } /** - * Return a read-only {@link Map} with "input" flash attributes saved on a - * previous request. - * @param request the current request + * Return read-only "input" flash attributes from request before redirect. + * @param request current request * @return a read-only Map, or {@code null} if not found * @see FlashMap */ @@ -221,23 +225,52 @@ public abstract class RequestContextUtils { } /** - * Return the "output" FlashMap with attributes to save for a subsequent request. - * @param request the current request - * @return a {@link FlashMap} instance (never {@code null} within a DispatcherServlet request) - * @see FlashMap + * Return "output" FlashMap to save attributes for request after redirect. + * @param request current request + * @return a {@link FlashMap} instance, never {@code null} within a + * {@code DispatcherServlet}-handled request */ public static FlashMap getOutputFlashMap(HttpServletRequest request) { return (FlashMap) request.getAttribute(DispatcherServlet.OUTPUT_FLASH_MAP_ATTRIBUTE); } /** - * Return the FlashMapManager instance to save flash attributes with - * before a redirect. + * Return the {@code FlashMapManager} instance to save flash attributes. + *

As of 5.0 the convenience method {@link #saveOutputFlashMap} may be + * used to save the "output" FlashMap. * @param request the current request - * @return a {@link FlashMapManager} instance (never {@code null} within a DispatcherServlet request) + * @return a {@link FlashMapManager} instance, never {@code null} within a + * {@code DispatcherServlet}-handled request */ public static FlashMapManager getFlashMapManager(HttpServletRequest request) { return (FlashMapManager) request.getAttribute(DispatcherServlet.FLASH_MAP_MANAGER_ATTRIBUTE); } + /** + * Convenience method that retrieves the {@link #getOutputFlashMap "output" + * FlashMap}, updates it with the path and query params of the target URL, + * and then saves it using the {@link #getFlashMapManager FlashMapManager}. + * + * @param location the target URL for the redirect + * @param request the current request + * @param response the current response + * @since 5.0 + */ + public static void saveOutputFlashMap(String location, HttpServletRequest request, + HttpServletResponse response) { + + FlashMap flashMap = getOutputFlashMap(request); + if (CollectionUtils.isEmpty(flashMap)) { + return; + } + + UriComponents uriComponents = UriComponentsBuilder.fromUriString(location).build(); + flashMap.setTargetRequestPath(uriComponents.getPath()); + flashMap.addTargetRequestParams(uriComponents.getQueryParams()); + + FlashMapManager manager = getFlashMapManager(request); + Assert.state(manager != null, "No FlashMapManager. Is this a DispatcherServlet handled request?"); + manager.saveOutputFlashMap(flashMap, request, response); + } + } diff --git a/spring-webmvc/src/main/java/org/springframework/web/servlet/view/RedirectView.java b/spring-webmvc/src/main/java/org/springframework/web/servlet/view/RedirectView.java index 0590fee3aec..fa2ef3ffdca 100644 --- a/spring-webmvc/src/main/java/org/springframework/web/servlet/view/RedirectView.java +++ b/spring-webmvc/src/main/java/org/springframework/web/servlet/view/RedirectView.java @@ -33,18 +33,14 @@ import javax.servlet.http.HttpServletResponse; import org.springframework.beans.BeanUtils; import org.springframework.http.HttpStatus; -import org.springframework.util.CollectionUtils; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; import org.springframework.web.context.WebApplicationContext; -import org.springframework.web.servlet.FlashMap; -import org.springframework.web.servlet.FlashMapManager; import org.springframework.web.servlet.HandlerMapping; import org.springframework.web.servlet.SmartView; import org.springframework.web.servlet.View; import org.springframework.web.servlet.support.RequestContextUtils; import org.springframework.web.servlet.support.RequestDataValueProcessor; -import org.springframework.web.util.UriComponents; import org.springframework.web.util.UriComponentsBuilder; import org.springframework.web.util.UriUtils; import org.springframework.web.util.WebUtils; @@ -305,18 +301,10 @@ public class RedirectView extends AbstractUrlBasedView implements SmartView { String targetUrl = createTargetUrl(model, request); targetUrl = updateTargetUrl(targetUrl, model, request, response); - FlashMap flashMap = RequestContextUtils.getOutputFlashMap(request); - if (!CollectionUtils.isEmpty(flashMap)) { - UriComponents uriComponents = UriComponentsBuilder.fromUriString(targetUrl).build(); - flashMap.setTargetRequestPath(uriComponents.getPath()); - flashMap.addTargetRequestParams(uriComponents.getQueryParams()); - FlashMapManager flashMapManager = RequestContextUtils.getFlashMapManager(request); - if (flashMapManager == null) { - throw new IllegalStateException("FlashMapManager not found despite output FlashMap having been set"); - } - flashMapManager.saveOutputFlashMap(flashMap, request, response); - } + // Save flash attributes + RequestContextUtils.saveOutputFlashMap(targetUrl, request, response); + // Redirect sendRedirect(request, response, targetUrl, this.http10Compatible); } diff --git a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ServletAnnotationControllerHandlerMethodTests.java b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ServletAnnotationControllerHandlerMethodTests.java index cd072124dc8..f21a23f58c7 100644 --- a/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ServletAnnotationControllerHandlerMethodTests.java +++ b/spring-webmvc/src/test/java/org/springframework/web/servlet/mvc/method/annotation/ServletAnnotationControllerHandlerMethodTests.java @@ -109,9 +109,11 @@ import org.springframework.web.accept.ContentNegotiationManagerFactoryBean; import org.springframework.web.bind.WebDataBinder; import org.springframework.web.bind.annotation.CookieValue; import org.springframework.web.bind.annotation.ExceptionHandler; +import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.InitBinder; import org.springframework.web.bind.annotation.ModelAttribute; import org.springframework.web.bind.annotation.PathVariable; +import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RequestHeader; import org.springframework.web.bind.annotation.RequestMapping; @@ -1569,6 +1571,32 @@ public class ServletAnnotationControllerHandlerMethodTests extends AbstractServl assertTrue(RequestContextUtils.getOutputFlashMap(request).isEmpty()); } + @Test // SPR-15176 + public void flashAttributesWithResponseEntity() throws Exception { + initServletWithControllers(RedirectAttributesController.class); + + MockHttpServletRequest request = new MockHttpServletRequest("POST", "/messages-response-entity"); + MockHttpServletResponse response = new MockHttpServletResponse(); + HttpSession session = request.getSession(); + + getServlet().service(request, response); + + assertEquals(302, response.getStatus()); + assertEquals("/messages/1?name=value", response.getRedirectedUrl()); + assertEquals("yay!", RequestContextUtils.getOutputFlashMap(request).get("successMessage")); + + // GET after POST + request = new MockHttpServletRequest("GET", "/messages/1"); + request.setQueryString("name=value"); + request.setSession(session); + response = new MockHttpServletResponse(); + getServlet().service(request, response); + + assertEquals(200, response.getStatus()); + assertEquals("Got: yay!", response.getContentAsString()); + assertTrue(RequestContextUtils.getOutputFlashMap(request).isEmpty()); + } + @Test public void prototypeController() throws Exception { initServlet(new ApplicationContextInitializer() { @@ -3215,21 +3243,26 @@ public class ServletAnnotationControllerHandlerMethodTests extends AbstractServl dataBinder.setRequiredFields("name"); } - @RequestMapping(value = "/messages/{id}", method = RequestMethod.GET) + @GetMapping("/messages/{id}") public void message(ModelMap model, Writer writer) throws IOException { writer.write("Got: " + model.get("successMessage")); } - @RequestMapping(value = "/messages", method = RequestMethod.POST) - public String sendMessage(TestBean testBean, BindingResult result, RedirectAttributes redirectAttrs) { + @PostMapping("/messages") + public String sendMessage(TestBean testBean, BindingResult result, RedirectAttributes attributes) { if (result.hasErrors()) { return "messages/new"; } - else { - redirectAttrs.addAttribute("id", "1").addAttribute("name", "value") - .addFlashAttribute("successMessage", "yay!"); - return "redirect:/messages/{id}"; - } + attributes.addAttribute("id", "1").addAttribute("name", "value"); + attributes.addFlashAttribute("successMessage", "yay!"); + return "redirect:/messages/{id}"; + } + + @PostMapping("/messages-response-entity") + public ResponseEntity sendMessage(RedirectAttributes attributes) { + attributes.addFlashAttribute("successMessage", "yay!"); + URI location = URI.create("/messages/1?name=value"); + return ResponseEntity.status(HttpStatus.FOUND).location(location).build(); } }