diff --git a/config/src/test/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParserTests.java b/config/src/test/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParserTests.java index 5a5cfacb96..fbb65be712 100644 --- a/config/src/test/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParserTests.java +++ b/config/src/test/java/org/springframework/security/config/http/HttpSecurityBeanDefinitionParserTests.java @@ -8,12 +8,15 @@ import static org.springframework.security.config.http.AuthenticationConfigBuild import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Collection; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.regex.Pattern; import javax.servlet.Filter; +import javax.servlet.http.HttpServletRequest; import org.junit.After; import org.junit.Test; @@ -39,6 +42,9 @@ import org.springframework.security.openid.OpenID4JavaConsumer; import org.springframework.security.openid.OpenIDAttribute; import org.springframework.security.openid.OpenIDAuthenticationFilter; import org.springframework.security.openid.OpenIDAuthenticationProvider; +import org.springframework.security.openid.OpenIDAuthenticationToken; +import org.springframework.security.openid.OpenIDConsumer; +import org.springframework.security.openid.OpenIDConsumerException; import org.springframework.security.util.FieldUtils; import org.springframework.security.web.FilterChainProxy; import org.springframework.security.web.FilterInvocation; @@ -63,6 +69,7 @@ import org.springframework.security.web.authentication.logout.LogoutHandler; import org.springframework.security.web.authentication.logout.LogoutSuccessHandler; import org.springframework.security.web.authentication.preauth.x509.SubjectDnX509PrincipalExtractor; import org.springframework.security.web.authentication.preauth.x509.X509AuthenticationFilter; +import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices; import org.springframework.security.web.authentication.rememberme.InMemoryTokenRepositoryImpl; import org.springframework.security.web.authentication.rememberme.PersistentTokenBasedRememberMeServices; import org.springframework.security.web.authentication.rememberme.RememberMeAuthenticationFilter; @@ -1070,6 +1077,55 @@ public class HttpSecurityBeanDefinitionParserTests { getFilter(DefaultLoginPageGeneratingFilter.class); } + @Test + public void openIDAndRememberMeWorkTogether() throws Exception { + setContext( + "" + + " " + + " " + + " " + + "" + + AUTH_PROVIDER_XML); + // Default login filter should be present since we haven't specified any login URLs + DefaultLoginPageGeneratingFilter loginFilter = getFilter(DefaultLoginPageGeneratingFilter.class); + OpenIDAuthenticationFilter openIDFilter = getFilter(OpenIDAuthenticationFilter.class); + openIDFilter.setConsumer(new OpenIDConsumer() { + public String beginConsumption(HttpServletRequest req, String claimedIdentity, String returnToUrl, String realm) + throws OpenIDConsumerException { + return "http://testopenid.com?openid.return_to=" + returnToUrl; + } + + public OpenIDAuthenticationToken endConsumption(HttpServletRequest req) throws OpenIDConsumerException { + throw new UnsupportedOperationException(); + } + }); + Set returnToUrlParameters = new HashSet(); + returnToUrlParameters.add(AbstractRememberMeServices.DEFAULT_PARAMETER); + openIDFilter.setReturnToUrlParameters(returnToUrlParameters); + assertNotNull(FieldUtils.getFieldValue(loginFilter, "openIDrememberMeParameter")); + MockHttpServletRequest request = new MockHttpServletRequest(); + MockHttpServletResponse response = new MockHttpServletResponse(); + + FilterChainProxy fcp = (FilterChainProxy) appContext.getBean(BeanIds.FILTER_CHAIN_PROXY); + request.setServletPath("/something.html"); + fcp.doFilter(request, response, new MockFilterChain()); + assertTrue(response.getRedirectedUrl().endsWith("/spring_security_login")); + request.setServletPath("/spring_security_login"); + request.setRequestURI("/spring_security_login"); + response = new MockHttpServletResponse(); + fcp.doFilter(request, response, new MockFilterChain()); + assertTrue(response.getContentAsString().contains(AbstractRememberMeServices.DEFAULT_PARAMETER)); + request.setRequestURI("/j_spring_openid_security_check"); + request.setParameter(OpenIDAuthenticationFilter.DEFAULT_CLAIMED_IDENTITY_FIELD, "http://hey.openid.com/"); + request.setParameter(AbstractRememberMeServices.DEFAULT_PARAMETER, "on"); + response = new MockHttpServletResponse(); + fcp.doFilter(request, response, new MockFilterChain()); + String expectedReturnTo = request.getRequestURL().append("?") + .append(AbstractRememberMeServices.DEFAULT_PARAMETER) + .append("=").append("on").toString(); + assertEquals("http://testopenid.com?openid.return_to=" + expectedReturnTo, response.getRedirectedUrl()); + } + @Test public void formLoginEntryPointTakesPrecedenceIfLoginUrlIsSet() throws Exception { setContext( diff --git a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java index 3f4664d6e9..e626888fac 100644 --- a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java +++ b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java @@ -19,7 +19,10 @@ import java.io.IOException; import java.net.MalformedURLException; import java.net.URL; import java.util.Collections; +import java.util.HashSet; +import java.util.Iterator; import java.util.Map; +import java.util.Set; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -31,6 +34,8 @@ import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.security.web.authentication.AbstractAuthenticationProcessingFilter; import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; +import org.springframework.security.web.authentication.rememberme.AbstractRememberMeServices; +import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -72,6 +77,7 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing private OpenIDConsumer consumer; private String claimedIdentityFieldName = DEFAULT_CLAIMED_IDENTITY_FIELD; private Map realmMapping = Collections.emptyMap(); + private Set returnToUrlParameters = Collections.emptySet(); //~ Constructors =================================================================================================== @@ -84,6 +90,7 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing @Override public void afterPropertiesSet() { super.afterPropertiesSet(); + if (consumer == null) { try { consumer = new OpenID4JavaConsumer(); @@ -91,6 +98,12 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing throw new IllegalArgumentException("Failed to initialize OpenID", e); } } + + if (returnToUrlParameters.isEmpty() && + getRememberMeServices() instanceof AbstractRememberMeServices) { + returnToUrlParameters = new HashSet(); + returnToUrlParameters.add(((AbstractRememberMeServices)getRememberMeServices()).getParameter()); + } } /** @@ -194,7 +207,32 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing * @return The return_to URL. */ protected String buildReturnToUrl(HttpServletRequest request) { - return request.getRequestURL().toString(); + StringBuffer sb = request.getRequestURL(); + + Iterator iterator = returnToUrlParameters.iterator(); + boolean isFirst = true; + + while (iterator.hasNext()) { + String name = iterator.next(); + // Assume for simplicity that there is only one value + String value = request.getParameter(name); + + if (value == null) { + continue; + } + + if (isFirst) { + sb.append("?"); + isFirst = false; + } + sb.append(name).append("=").append(value); + + if (iterator.hasNext()) { + sb.append("&"); + } + } + + return sb.toString(); } /** @@ -232,4 +270,17 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing public void setConsumer(OpenIDConsumer consumer) { this.consumer = consumer; } + + /** + * Specifies any extra parameters submitted along with the identity field which should be appended to the + * {@literal return_to} URL which is assembled by {@link #buildReturnToUrl}. + * + * @param returnToUrlParameters + * the set of parameter names. If not set, it will default to the parameter name used by the + * {@code RememberMeServices} obtained from the parent class (if one is set). + */ + public void setReturnToUrlParameters(Set returnToUrlParameters) { + Assert.notNull(returnToUrlParameters, "returnToUrlParameters cannot be null"); + this.returnToUrlParameters = returnToUrlParameters; + } } diff --git a/samples/openid/src/main/webapp/WEB-INF/applicationContext-security.xml b/samples/openid/src/main/webapp/WEB-INF/applicationContext-security.xml index 4adbf90c69..fbb072e491 100644 --- a/samples/openid/src/main/webapp/WEB-INF/applicationContext-security.xml +++ b/samples/openid/src/main/webapp/WEB-INF/applicationContext-security.xml @@ -1,8 +1,7 @@ + + + diff --git a/samples/openid/src/main/webapp/openidlogin.jsp b/samples/openid/src/main/webapp/openidlogin.jsp index 02aeebe4d0..33a70df4e5 100644 --- a/samples/openid/src/main/webapp/openidlogin.jsp +++ b/samples/openid/src/main/webapp/openidlogin.jsp @@ -22,7 +22,7 @@ OpenID Identity: - + Remember me on this computer. diff --git a/web/src/main/java/org/springframework/security/web/authentication/rememberme/AbstractRememberMeServices.java b/web/src/main/java/org/springframework/security/web/authentication/rememberme/AbstractRememberMeServices.java index 9a2a016d34..77973f628a 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/rememberme/AbstractRememberMeServices.java +++ b/web/src/main/java/org/springframework/security/web/authentication/rememberme/AbstractRememberMeServices.java @@ -165,7 +165,17 @@ public abstract class AbstractRememberMeServices implements RememberMeServices, String cookieAsPlainText = new String(Base64.decode(cookieValue.getBytes())); - return StringUtils.delimitedListToStringArray(cookieAsPlainText, DELIMITER); + String[] tokens = StringUtils.delimitedListToStringArray(cookieAsPlainText, DELIMITER); + + if (tokens[0].equalsIgnoreCase("http") && tokens[1].startsWith("//")) { + // Assume we've accidentally split a URL (OpenID identifier) + String[] newTokens = new String[tokens.length - 1]; + newTokens[0] = "http:" + tokens[1]; + System.arraycopy(tokens, 2, newTokens, 1, newTokens.length - 1); + tokens = newTokens; + } + + return tokens; } /** diff --git a/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java b/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java index a6bb07319c..50a5b2b1ca 100644 --- a/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java +++ b/web/src/main/java/org/springframework/security/web/authentication/ui/DefaultLoginPageGeneratingFilter.java @@ -147,7 +147,7 @@ public class DefaultLoginPageGeneratingFilter extends GenericFilterBean { sb.append(" Identity:\n"); - if (rememberMeParameter != null) { + if (openIDrememberMeParameter != null) { sb.append(" Remember me on this computer.\n"); } diff --git a/web/src/test/java/org/springframework/security/web/authentication/rememberme/AbstractRememberMeServicesTests.java b/web/src/test/java/org/springframework/security/web/authentication/rememberme/AbstractRememberMeServicesTests.java index 0c80c3a7f6..5c96845976 100644 --- a/web/src/test/java/org/springframework/security/web/authentication/rememberme/AbstractRememberMeServicesTests.java +++ b/web/src/test/java/org/springframework/security/web/authentication/rememberme/AbstractRememberMeServicesTests.java @@ -35,7 +35,7 @@ public class AbstractRememberMeServicesTests { @Test public void cookieShouldBeCorrectlyEncodedAndDecoded() { - String[] cookie = new String[] {"the", "cookie", "tokens", "blah"}; + String[] cookie = new String[] {"http://name", "cookie", "tokens", "blah"}; MockRememberMeServices services = new MockRememberMeServices(); String encoded = services.encodeCookie(cookie); @@ -44,7 +44,7 @@ public class AbstractRememberMeServicesTests { String[] decoded = services.decodeCookie(encoded); assertEquals(4, decoded.length); - assertEquals("the", decoded[0]); + assertEquals("http://name", decoded[0]); assertEquals("cookie", decoded[1]); assertEquals("tokens", decoded[2]); assertEquals("blah", decoded[3]);