From db6da77a5f7693a0eae85c43fe1d7527436419df Mon Sep 17 00:00:00 2001 From: Luke Taylor Date: Tue, 10 Aug 2010 17:39:12 +0100 Subject: [PATCH] SEC-1413: Add RedirectStrategy to AbstractRetryEntryPoint. --- .../channel/AbstractRetryEntryPoint.java | 36 +++++++++++++------ .../channel/RetryWithHttpEntryPointTests.java | 18 +++++++--- .../RetryWithHttpsEntryPointTests.java | 4 --- 3 files changed, 39 insertions(+), 19 deletions(-) diff --git a/web/src/main/java/org/springframework/security/web/access/channel/AbstractRetryEntryPoint.java b/web/src/main/java/org/springframework/security/web/access/channel/AbstractRetryEntryPoint.java index de34ad497d..e38bc2b8db 100644 --- a/web/src/main/java/org/springframework/security/web/access/channel/AbstractRetryEntryPoint.java +++ b/web/src/main/java/org/springframework/security/web/access/channel/AbstractRetryEntryPoint.java @@ -1,9 +1,6 @@ package org.springframework.security.web.access.channel; -import org.springframework.security.web.PortMapper; -import org.springframework.security.web.PortMapperImpl; -import org.springframework.security.web.PortResolver; -import org.springframework.security.web.PortResolverImpl; +import org.springframework.security.web.*; import org.springframework.util.Assert; import org.apache.commons.logging.Log; @@ -30,6 +27,8 @@ public abstract class AbstractRetryEntryPoint implements ChannelEntryPoint { /** The standard port for the scheme (80 for http, 443 for https) */ private final int standardPort; + private RedirectStrategy redirectStrategy = new DefaultRedirectStrategy(); + //~ Constructors =================================================================================================== public AbstractRetryEntryPoint(String scheme, int standardPort) { @@ -39,11 +38,11 @@ public abstract class AbstractRetryEntryPoint implements ChannelEntryPoint { //~ Methods ======================================================================================================== - public void commence(HttpServletRequest request, HttpServletResponse res) throws IOException, ServletException { + public void commence(HttpServletRequest request, HttpServletResponse response) throws IOException, ServletException { String queryString = request.getQueryString(); String redirectUrl = request.getRequestURI() + ((queryString == null) ? "" : ("?" + queryString)); - Integer currentPort = new Integer(portResolver.getServerPort(request)); + Integer currentPort = Integer.valueOf(portResolver.getServerPort(request)); Integer redirectPort = getMappedPort(currentPort); if (redirectPort != null) { @@ -56,7 +55,7 @@ public abstract class AbstractRetryEntryPoint implements ChannelEntryPoint { logger.debug("Redirecting to: " + redirectUrl); } - res.sendRedirect(res.encodeRedirectURL(redirectUrl)); + redirectStrategy.sendRedirect(request, response, redirectUrl); } protected abstract Integer getMappedPort(Integer mapFromPort); @@ -65,10 +64,6 @@ public abstract class AbstractRetryEntryPoint implements ChannelEntryPoint { return portMapper; } - protected final PortResolver getPortResolver() { - return portResolver; - } - public void setPortMapper(PortMapper portMapper) { Assert.notNull(portMapper, "portMapper cannot be null"); this.portMapper = portMapper; @@ -78,4 +73,23 @@ public abstract class AbstractRetryEntryPoint implements ChannelEntryPoint { Assert.notNull(portResolver, "portResolver cannot be null"); this.portResolver = portResolver; } + + protected final PortResolver getPortResolver() { + return portResolver; + } + + /** + * Sets the strategy to be used for redirecting to the required channel URL. A {@code DefaultRedirectStrategy} + * instance will be used if not set. + * + * @param redirectStrategy the strategy instance to which the URL will be passed. + */ + public void setRedirectStrategy(RedirectStrategy redirectStrategy) { + Assert.notNull(redirectStrategy, "redirectStrategy cannot be null"); + this.redirectStrategy = redirectStrategy; + } + + protected final RedirectStrategy getRedirectStrategy() { + return redirectStrategy; + } } diff --git a/web/src/test/java/org/springframework/security/web/access/channel/RetryWithHttpEntryPointTests.java b/web/src/test/java/org/springframework/security/web/access/channel/RetryWithHttpEntryPointTests.java index 5fb5b3f2f8..9f6fe90b8d 100644 --- a/web/src/test/java/org/springframework/security/web/access/channel/RetryWithHttpEntryPointTests.java +++ b/web/src/test/java/org/springframework/security/web/access/channel/RetryWithHttpEntryPointTests.java @@ -15,11 +15,16 @@ package org.springframework.security.web.access.channel; +import static org.mockito.Mockito.mock; + import junit.framework.TestCase; import org.springframework.security.MockPortResolver; +import org.springframework.security.web.PortMapper; import org.springframework.security.web.PortMapperImpl; +import org.springframework.security.web.PortResolver; +import org.springframework.security.web.RedirectStrategy; import org.springframework.security.web.access.channel.RetryWithHttpEntryPoint; import org.springframework.mock.web.MockHttpServletRequest; @@ -59,10 +64,15 @@ public class RetryWithHttpEntryPointTests extends TestCase { public void testGettersSetters() { RetryWithHttpEntryPoint ep = new RetryWithHttpEntryPoint(); - ep.setPortMapper(new PortMapperImpl()); - ep.setPortResolver(new MockPortResolver(8080, 8443)); - assertTrue(ep.getPortMapper() != null); - assertTrue(ep.getPortResolver() != null); + PortMapper portMapper = mock(PortMapper.class); + PortResolver portResolver = mock(PortResolver.class); + RedirectStrategy redirector = mock(RedirectStrategy.class); + ep.setPortMapper(portMapper); + ep.setPortResolver(portResolver); + ep.setRedirectStrategy(redirector); + assertSame(portMapper, ep.getPortMapper()); + assertSame(portResolver, ep.getPortResolver()); + assertSame(redirector, ep.getRedirectStrategy()); } public void testNormalOperation() throws Exception { diff --git a/web/src/test/java/org/springframework/security/web/access/channel/RetryWithHttpsEntryPointTests.java b/web/src/test/java/org/springframework/security/web/access/channel/RetryWithHttpsEntryPointTests.java index c49abadada..54d8b679d3 100644 --- a/web/src/test/java/org/springframework/security/web/access/channel/RetryWithHttpsEntryPointTests.java +++ b/web/src/test/java/org/springframework/security/web/access/channel/RetryWithHttpsEntryPointTests.java @@ -37,10 +37,6 @@ import java.util.Map; public class RetryWithHttpsEntryPointTests extends TestCase { //~ Methods ======================================================================================================== - public final void setUp() throws Exception { - super.setUp(); - } - public void testDetectsMissingPortMapper() throws Exception { RetryWithHttpsEntryPoint ep = new RetryWithHttpsEntryPoint();