diff --git a/web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java b/web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java index 2563d1f20e..725c67a5cb 100644 --- a/web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java +++ b/web/src/main/java/org/springframework/security/web/context/HttpSessionSecurityContextRepository.java @@ -1,5 +1,13 @@ package org.springframework.security.web.context; +import javax.servlet.AsyncContext; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletRequestWrapper; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpSession; + import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.springframework.security.authentication.AuthenticationTrustResolver; @@ -9,11 +17,7 @@ import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolderStrategy; import org.springframework.util.Assert; -import org.springframework.util.ReflectionUtils; - -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; +import org.springframework.util.ClassUtils; /** * A {@code SecurityContextRepository} implementation which stores the security context in the {@code HttpSession} @@ -62,6 +66,7 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo private final Object contextObject = SecurityContextHolder.createEmptyContext(); private boolean allowSessionCreation = true; private boolean disableUrlRewriting = false; + private boolean isServlet3 = ClassUtils.hasMethod(ServletRequest.class, "startAsync"); private String springSecurityContextKey = SPRING_SECURITY_CONTEXT_KEY; private final AuthenticationTrustResolver authenticationTrustResolver = new AuthenticationTrustResolverImpl(); @@ -89,8 +94,12 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo } - requestResponseHolder.setResponse( - new SaveToSessionResponseWrapper(response, request, httpSession != null, context)); + SaveToSessionResponseWrapper wrappedResponse = new SaveToSessionResponseWrapper(response, request, httpSession != null, context); + requestResponseHolder.setResponse(wrappedResponse); + + if(isServlet3) { + requestResponseHolder.setRequest(new Servlet3SaveToSessionRequestWrapper(request, wrappedResponse)); + } return context; } @@ -212,6 +221,28 @@ public class HttpSessionSecurityContextRepository implements SecurityContextRepo //~ Inner Classes ================================================================================================== + private static class Servlet3SaveToSessionRequestWrapper extends HttpServletRequestWrapper { + private final SaveContextOnUpdateOrErrorResponseWrapper response; + + public Servlet3SaveToSessionRequestWrapper(HttpServletRequest request,SaveContextOnUpdateOrErrorResponseWrapper response) { + super(request); + this.response = response; + } + + @Override + public AsyncContext startAsync() { + response.disableSaveOnResponseCommitted(); + return super.startAsync(); + } + + @Override + public AsyncContext startAsync(ServletRequest servletRequest, + ServletResponse servletResponse) throws IllegalStateException { + response.disableSaveOnResponseCommitted(); + return super.startAsync(servletRequest, servletResponse); + } + } + /** * Wrapper that is applied to every request/response to update the HttpSession with * the SecurityContext when a sendError() or sendRedirect diff --git a/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java b/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java index e60f3bbad0..0c9983b4bf 100644 --- a/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java +++ b/web/src/main/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapper.java @@ -44,7 +44,7 @@ import org.springframework.security.core.context.SecurityContextHolder; public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends HttpServletResponseWrapper { private final Log logger = LogFactory.getLog(getClass()); - private final Thread SUPPORTED_THREAD = Thread.currentThread(); + private boolean disableSaveOnResponseCommitted; private boolean contextSaved = false; /* See SEC-1052 */ @@ -60,6 +60,16 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends HttpServ this.disableUrlRewriting = disableUrlRewriting; } + /** + * Invoke this method to disable automatic saving of the + * {@link SecurityContext} when the {@link HttpServletResponse} is + * committed. This can be useful in the event that Async Web Requests are + * made which may no longer contain the {@link SecurityContext} on it. + */ + public void disableSaveOnResponseCommitted() { + this.disableSaveOnResponseCommitted = true; + } + /** * Implements the logic for storing the security context. * @@ -126,18 +136,16 @@ public abstract class SaveContextOnUpdateOrErrorResponseWrapper extends HttpServ } /** - * Calls saveContext() with the current contents of the SecurityContextHolder as long as - * {@link #doSaveContext()} is invoked on the same Thread that {@link SaveContextOnUpdateOrErrorResponseWrapper} was - * created on. This prevents issues when dealing with Async Web Requests where the {@link SecurityContext} is not - * present on the Thread that processes the response. + * Calls saveContext() with the current contents of the + * SecurityContextHolder as long as + * {@link #disableSaveOnResponseCommitted()()} was not invoked. */ private void doSaveContext() { - Thread currentThread = Thread.currentThread(); - if(SUPPORTED_THREAD == currentThread) { + if(!disableSaveOnResponseCommitted) { saveContext(SecurityContextHolder.getContext()); contextSaved = true; } else if(logger.isDebugEnabled()){ - logger.debug("Skip saving SecurityContext since processing the HttpServletResponse on a different Thread than the original HttpServletRequest"); + logger.debug("Skip saving SecurityContext since saving on response commited is disabled"); } } diff --git a/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java b/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java index 75f634b975..c3f3628be6 100644 --- a/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java +++ b/web/src/test/java/org/springframework/security/web/context/HttpSessionSecurityContextRepositoryTests.java @@ -12,16 +12,26 @@ */ package org.springframework.security.web.context; +import static org.fest.assertions.Assertions.assertThat; import static org.junit.Assert.*; -import static org.mockito.Mockito.*; import static org.springframework.security.web.context.HttpSessionSecurityContextRepository.*; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Matchers.*; +import static org.powermock.api.mockito.PowerMockito.*; import javax.servlet.ServletOutputStream; +import javax.servlet.ServletRequest; +import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpSession; import org.junit.After; import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockHttpServletResponse; import org.springframework.security.authentication.AnonymousAuthenticationToken; @@ -29,11 +39,14 @@ import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.util.ClassUtils; /** * @author Luke Taylor * @author Rob Winch */ +@RunWith(PowerMockRunner.class) +@PrepareForTest({ClassUtils.class}) public class HttpSessionSecurityContextRepositoryTests { private final TestingAuthenticationToken testToken = new TestingAuthenticationToken("someone", "passwd", "ROLE_A"); @@ -42,6 +55,53 @@ public class HttpSessionSecurityContextRepositoryTests { SecurityContextHolder.clearContext(); } + @Test + public void servlet25Compatability() throws Exception { + spy(ClassUtils.class); + when(ClassUtils.class,"hasMethod", ServletRequest.class, "startAsync", new Class[] {}).thenReturn(false); + HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response); + repo.loadContext(holder); + assertThat(holder.getRequest()).isSameAs(request); + } + + @Test + public void startAsyncDisablesSaveOnCommit() throws Exception { + HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); + HttpServletRequest request = mock(HttpServletRequest.class); + MockHttpServletResponse response = new MockHttpServletResponse(); + HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response); + repo.loadContext(holder); + + reset(request); + holder.getRequest().startAsync(); + holder.getResponse().sendError(HttpServletResponse.SC_BAD_REQUEST); + + // ensure that sendError did cause interaction with the HttpSession + verify(request, never()).getSession(anyBoolean()); + verify(request, never()).getSession(); + } + + + @Test + public void startAsyncRequestResponseDisablesSaveOnCommit() throws Exception { + HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); + HttpServletRequest request = mock(HttpServletRequest.class); + MockHttpServletResponse response = new MockHttpServletResponse(); + HttpRequestResponseHolder holder = new HttpRequestResponseHolder(request, response); + repo.loadContext(holder); + + reset(request); + holder.getRequest().startAsync(request,response); + holder.getResponse().sendError(HttpServletResponse.SC_BAD_REQUEST); + + // ensure that sendError did cause interaction with the HttpSession + verify(request, never()).getSession(anyBoolean()); + verify(request, never()).getSession(); + } + @Test public void sessionIsntCreatedIfContextDoesntChange() throws Exception { HttpSessionSecurityContextRepository repo = new HttpSessionSecurityContextRepository(); diff --git a/web/src/test/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapperTests.java b/web/src/test/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapperTests.java index 2f39db55ac..f41da7bedd 100644 --- a/web/src/test/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapperTests.java +++ b/web/src/test/java/org/springframework/security/web/context/SaveContextOnUpdateOrErrorResponseWrapperTests.java @@ -14,8 +14,6 @@ package org.springframework.security.web.context; import static org.fest.assertions.Assertions.assertThat; -import java.io.IOException; - import javax.servlet.http.HttpServletResponse; import org.junit.After; @@ -63,19 +61,10 @@ public class SaveContextOnUpdateOrErrorResponseWrapperTests { } @Test - public void sendErrorSkipsSaveSecurityContextOnNewThread() throws Exception { + public void sendErrorSkipsSaveSecurityContextDisables() throws Exception { final int error = HttpServletResponse.SC_FORBIDDEN; - Thread t = new Thread() { - public void run() { - try { - wrappedResponse.sendError(error); - } catch(IOException e) { - throw new RuntimeException(e); - } - } - }; - t.start(); - t.join(); + wrappedResponse.disableSaveOnResponseCommitted(); + wrappedResponse.sendError(error); assertThat(wrappedResponse.securityContext).isNull(); assertThat(response.getStatus()).isEqualTo(error); } @@ -91,20 +80,11 @@ public class SaveContextOnUpdateOrErrorResponseWrapperTests { } @Test - public void sendErrorWithMessageSkipsSaveSecurityContextOnNewThread() throws Exception { + public void sendErrorWithMessageSkipsSaveSecurityContextDisables() throws Exception { final int error = HttpServletResponse.SC_FORBIDDEN; final String message = "Forbidden"; - Thread t = new Thread() { - public void run() { - try { - wrappedResponse.sendError(error, message); - } catch(IOException e) { - throw new RuntimeException(e); - } - } - }; - t.start(); - t.join(); + wrappedResponse.disableSaveOnResponseCommitted(); + wrappedResponse.sendError(error, message); assertThat(wrappedResponse.securityContext).isNull(); assertThat(response.getStatus()).isEqualTo(error); assertThat(response.getErrorMessage()).isEqualTo(message); @@ -119,19 +99,10 @@ public class SaveContextOnUpdateOrErrorResponseWrapperTests { } @Test - public void sendRedirectSkipsSaveSecurityContextOnNewThread() throws Exception { + public void sendRedirectSkipsSaveSecurityContextDisables() throws Exception { final String url = "/location"; - Thread t = new Thread() { - public void run() { - try { - wrappedResponse.sendRedirect(url); - } catch(IOException e) { - throw new RuntimeException(e); - } - } - }; - t.start(); - t.join(); + wrappedResponse.disableSaveOnResponseCommitted(); + wrappedResponse.sendRedirect(url); assertThat(wrappedResponse.securityContext).isNull(); assertThat(response.getRedirectedUrl()).isEqualTo(url); } @@ -143,18 +114,9 @@ public class SaveContextOnUpdateOrErrorResponseWrapperTests { } @Test - public void outputFlushSkipsSaveSecurityContextOnNewThread() throws Exception { - Thread t = new Thread() { - public void run() { - try { - wrappedResponse.getOutputStream().flush(); - } catch(IOException e) { - throw new RuntimeException(e); - } - } - }; - t.start(); - t.join(); + public void outputFlushSkipsSaveSecurityContextDisables() throws Exception { + wrappedResponse.disableSaveOnResponseCommitted(); + wrappedResponse.getOutputStream().flush(); assertThat(wrappedResponse.securityContext).isNull(); } @@ -165,18 +127,9 @@ public class SaveContextOnUpdateOrErrorResponseWrapperTests { } @Test - public void outputCloseSkipsSaveSecurityContextOnNewThread() throws Exception { - Thread t = new Thread() { - public void run() { - try { - wrappedResponse.getOutputStream().close(); - } catch(IOException e) { - throw new RuntimeException(e); - } - } - }; - t.start(); - t.join(); + public void outputCloseSkipsSaveSecurityContextDisables() throws Exception { + wrappedResponse.disableSaveOnResponseCommitted(); + wrappedResponse.getOutputStream().close(); assertThat(wrappedResponse.securityContext).isNull(); } @@ -187,18 +140,9 @@ public class SaveContextOnUpdateOrErrorResponseWrapperTests { } @Test - public void writerFlushSkipsSaveSecurityContextOnNewThread() throws Exception { - Thread t = new Thread() { - public void run() { - try { - wrappedResponse.getWriter().flush(); - } catch(IOException e) { - throw new RuntimeException(e); - } - } - }; - t.start(); - t.join(); + public void writerFlushSkipsSaveSecurityContextDisables() throws Exception { + wrappedResponse.disableSaveOnResponseCommitted(); + wrappedResponse.getWriter().flush(); assertThat(wrappedResponse.securityContext).isNull(); } @@ -209,18 +153,9 @@ public class SaveContextOnUpdateOrErrorResponseWrapperTests { } @Test - public void writerCloseSkipsSaveSecurityContextOnNewThread() throws Exception { - Thread t = new Thread() { - public void run() { - try { - wrappedResponse.getWriter().close(); - } catch(IOException e) { - throw new RuntimeException(e); - } - } - }; - t.start(); - t.join(); + public void writerCloseSkipsSaveSecurityContextDisables() throws Exception { + wrappedResponse.disableSaveOnResponseCommitted(); + wrappedResponse.getWriter().close(); assertThat(wrappedResponse.securityContext).isNull(); } @@ -231,18 +166,9 @@ public class SaveContextOnUpdateOrErrorResponseWrapperTests { } @Test - public void flushBufferSkipsSaveSecurityContextOnNewThread() throws Exception { - Thread t = new Thread() { - public void run() { - try { - wrappedResponse.flushBuffer(); - } catch(IOException e) { - throw new RuntimeException(e); - } - } - }; - t.start(); - t.join(); + public void flushBufferSkipsSaveSecurityContextDisables() throws Exception { + wrappedResponse.disableSaveOnResponseCommitted(); + wrappedResponse.flushBuffer(); assertThat(wrappedResponse.securityContext).isNull(); }