13 changed files with 872 additions and 6 deletions
@ -0,0 +1,62 @@
@@ -0,0 +1,62 @@
|
||||
/* |
||||
* Copyright 2002-2015 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package org.springframework.security.messaging.web.csrf; |
||||
|
||||
import java.util.Map; |
||||
|
||||
import org.springframework.messaging.Message; |
||||
import org.springframework.messaging.MessageChannel; |
||||
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; |
||||
import org.springframework.messaging.simp.SimpMessageType; |
||||
import org.springframework.messaging.support.ChannelInterceptorAdapter; |
||||
import org.springframework.security.messaging.util.matcher.MessageMatcher; |
||||
import org.springframework.security.messaging.util.matcher.SimpMessageTypeMatcher; |
||||
import org.springframework.security.web.csrf.CsrfToken; |
||||
import org.springframework.security.web.csrf.InvalidCsrfTokenException; |
||||
import org.springframework.security.web.csrf.MissingCsrfTokenException; |
||||
|
||||
/** |
||||
* {@link ChannelInterceptorAdapter} that validates that a valid CSRF is included in the header of any |
||||
* {@link SimpMessageType#CONNECT} message. The expected {@link CsrfToken} is populated by CsrfTokenHandshakeInterceptor. |
||||
* |
||||
* @author Rob Winch |
||||
* @since 4.0 |
||||
*/ |
||||
public final class CsrfChannelInterceptor extends ChannelInterceptorAdapter { |
||||
private final MessageMatcher<Object> matcher = new SimpMessageTypeMatcher(SimpMessageType.CONNECT); |
||||
|
||||
@Override |
||||
public Message<?> preSend(Message<?> message, MessageChannel channel) { |
||||
if(!matcher.matches(message)) { |
||||
return message; |
||||
} |
||||
|
||||
Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders()); |
||||
CsrfToken expectedToken = sessionAttributes == null ? null : (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()); |
||||
|
||||
if(expectedToken == null) { |
||||
throw new MissingCsrfTokenException(null); |
||||
} |
||||
|
||||
String actualTokenValue = SimpMessageHeaderAccessor.wrap(message).getFirstNativeHeader(expectedToken.getHeaderName()); |
||||
|
||||
boolean csrfCheckPassed = expectedToken.getToken().equals(actualTokenValue); |
||||
if(csrfCheckPassed) { |
||||
return message; |
||||
} |
||||
throw new InvalidCsrfTokenException(expectedToken, actualTokenValue); |
||||
} |
||||
} |
||||
@ -0,0 +1,54 @@
@@ -0,0 +1,54 @@
|
||||
/* |
||||
* Copyright 2002-2015 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package org.springframework.security.messaging.web.socket.server; |
||||
|
||||
import java.util.Map; |
||||
|
||||
import javax.servlet.http.HttpServletRequest; |
||||
|
||||
import org.springframework.http.server.ServerHttpRequest; |
||||
import org.springframework.http.server.ServerHttpResponse; |
||||
import org.springframework.http.server.ServletServerHttpRequest; |
||||
import org.springframework.security.web.csrf.CsrfToken; |
||||
import org.springframework.web.socket.WebSocketHandler; |
||||
import org.springframework.web.socket.server.HandshakeInterceptor; |
||||
|
||||
/** |
||||
* Copies a CsrfToken from the HttpServletRequest's attributes to the WebSocket attributes. This is used as the |
||||
* expected CsrfToken when validating connection requests to ensure only the same origin connects. |
||||
* |
||||
* @author Rob Winch |
||||
* @since 4.0 |
||||
*/ |
||||
public final class CsrfTokenHandshakeInterceptor implements HandshakeInterceptor { |
||||
|
||||
public boolean beforeHandshake(ServerHttpRequest request, |
||||
ServerHttpResponse response, WebSocketHandler wsHandler, |
||||
Map<String, Object> attributes) throws Exception { |
||||
HttpServletRequest httpRequest = ((ServletServerHttpRequest)request).getServletRequest(); |
||||
CsrfToken token = (CsrfToken) httpRequest.getAttribute(CsrfToken.class.getName()); |
||||
if(token == null) { |
||||
return true; |
||||
} |
||||
attributes.put(CsrfToken.class.getName(), token); |
||||
return true; |
||||
} |
||||
|
||||
public void afterHandshake(ServerHttpRequest request, |
||||
ServerHttpResponse response, WebSocketHandler wsHandler, |
||||
Exception exception) { |
||||
} |
||||
} |
||||
@ -0,0 +1,154 @@
@@ -0,0 +1,154 @@
|
||||
/* |
||||
* Copyright 2002-2015 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package org.springframework.security.messaging.web.csrf; |
||||
|
||||
import java.util.HashMap; |
||||
import java.util.Map; |
||||
|
||||
import org.junit.Before; |
||||
import org.junit.Test; |
||||
import org.junit.runner.RunWith; |
||||
import org.mockito.Mock; |
||||
import org.mockito.runners.MockitoJUnitRunner; |
||||
import org.springframework.messaging.Message; |
||||
import org.springframework.messaging.MessageChannel; |
||||
import org.springframework.messaging.simp.SimpMessageHeaderAccessor; |
||||
import org.springframework.messaging.simp.SimpMessageType; |
||||
import org.springframework.messaging.support.MessageBuilder; |
||||
import org.springframework.security.web.csrf.CsrfToken; |
||||
import org.springframework.security.web.csrf.DefaultCsrfToken; |
||||
import org.springframework.security.web.csrf.InvalidCsrfTokenException; |
||||
import org.springframework.security.web.csrf.MissingCsrfTokenException; |
||||
|
||||
@RunWith(MockitoJUnitRunner.class) |
||||
public class CsrfChannelInterceptorTests { |
||||
@Mock |
||||
MessageChannel channel; |
||||
|
||||
SimpMessageHeaderAccessor messageHeaders; |
||||
|
||||
CsrfToken token; |
||||
|
||||
CsrfChannelInterceptor interceptor; |
||||
|
||||
@Before |
||||
public void setup() { |
||||
token = new DefaultCsrfToken("header", "param", "token"); |
||||
interceptor = new CsrfChannelInterceptor(); |
||||
|
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT); |
||||
messageHeaders.setNativeHeader(token.getHeaderName(), token.getToken()); |
||||
messageHeaders.setSessionAttributes(new HashMap<String,Object>()); |
||||
messageHeaders.getSessionAttributes().put(CsrfToken.class.getName(), token); |
||||
} |
||||
|
||||
@Test |
||||
public void preSendValidToken() { |
||||
interceptor.preSend(message(), channel); |
||||
} |
||||
|
||||
@Test |
||||
public void preSendIgnoresConnectAck() { |
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.CONNECT_ACK); |
||||
|
||||
interceptor.preSend(message(), channel); |
||||
} |
||||
|
||||
@Test |
||||
public void preSendIgnoresDisconnect() { |
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT); |
||||
|
||||
interceptor.preSend(message(), channel); |
||||
} |
||||
|
||||
@Test |
||||
public void preSendIgnoresDisconnectAck() { |
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.DISCONNECT_ACK); |
||||
|
||||
interceptor.preSend(message(), channel); |
||||
} |
||||
|
||||
@Test |
||||
public void preSendIgnoresHeartbeat() { |
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.HEARTBEAT); |
||||
|
||||
interceptor.preSend(message(), channel); |
||||
} |
||||
|
||||
@Test |
||||
public void preSendIgnoresMessage() { |
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.MESSAGE); |
||||
|
||||
interceptor.preSend(message(), channel); |
||||
} |
||||
|
||||
@Test |
||||
public void preSendIgnoresOther() { |
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.OTHER); |
||||
|
||||
interceptor.preSend(message(), channel); |
||||
} |
||||
|
||||
@Test |
||||
public void preSendIgnoresSubscribe() { |
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.SUBSCRIBE); |
||||
|
||||
interceptor.preSend(message(), channel); |
||||
} |
||||
|
||||
@Test |
||||
public void preSendIgnoresUnsubscribe() { |
||||
messageHeaders = SimpMessageHeaderAccessor.create(SimpMessageType.UNSUBSCRIBE); |
||||
|
||||
interceptor.preSend(message(), channel); |
||||
} |
||||
|
||||
@Test(expected = InvalidCsrfTokenException.class) |
||||
public void preSendNoToken() { |
||||
messageHeaders.removeNativeHeader(token.getHeaderName()); |
||||
|
||||
interceptor.preSend(message(), channel); |
||||
} |
||||
|
||||
@Test(expected = InvalidCsrfTokenException.class) |
||||
public void preSendInvalidToken() { |
||||
messageHeaders.setNativeHeader(token.getHeaderName(), token.getToken() + "invalid"); |
||||
|
||||
interceptor.preSend(message(), channel); |
||||
} |
||||
|
||||
@Test(expected = MissingCsrfTokenException.class) |
||||
public void preSendMissingToken() { |
||||
messageHeaders.getSessionAttributes().clear(); |
||||
|
||||
interceptor.preSend(message(), channel); |
||||
} |
||||
|
||||
@Test(expected = MissingCsrfTokenException.class) |
||||
public void preSendMissingTokenNullSessionAttributes() { |
||||
messageHeaders.setSessionAttributes(null); |
||||
|
||||
interceptor.preSend(message(), channel); |
||||
} |
||||
|
||||
private Message<String> message() { |
||||
Map<String, Object> headersToCopy = messageHeaders.toMap(); |
||||
return MessageBuilder |
||||
.withPayload("hi") |
||||
.copyHeaders(headersToCopy) |
||||
.build(); |
||||
} |
||||
} |
||||
@ -0,0 +1,83 @@
@@ -0,0 +1,83 @@
|
||||
/* |
||||
* Copyright 2002-2015 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not |
||||
* use this file except in compliance with the License. You may obtain a copy of |
||||
* the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
||||
* License for the specific language governing permissions and limitations under |
||||
* the License. |
||||
*/ |
||||
package org.springframework.security.messaging.web.socket.server; |
||||
|
||||
import org.junit.Test; |
||||
import org.junit.Before; |
||||
import org.junit.runner.RunWith; |
||||
import org.mockito.Mock; |
||||
import org.mockito.runners.MockitoJUnitRunner; |
||||
import org.springframework.http.server.ServerHttpRequest; |
||||
import org.springframework.http.server.ServerHttpResponse; |
||||
import org.springframework.http.server.ServletServerHttpRequest; |
||||
import org.springframework.mock.web.MockHttpServletRequest; |
||||
import org.springframework.security.web.csrf.CsrfToken; |
||||
import org.springframework.security.web.csrf.DefaultCsrfToken; |
||||
import org.springframework.web.socket.WebSocketHandler; |
||||
|
||||
import java.util.HashMap; |
||||
import java.util.Map; |
||||
|
||||
import static org.fest.assertions.Assertions.assertThat; |
||||
|
||||
|
||||
/** |
||||
* |
||||
* @author Rob Winch |
||||
*/ |
||||
@RunWith(MockitoJUnitRunner.class) |
||||
public class CsrfTokenHandshakeInterceptorTests { |
||||
@Mock |
||||
WebSocketHandler wsHandler; |
||||
@Mock |
||||
ServerHttpResponse response; |
||||
|
||||
Map<String, Object> attributes; |
||||
|
||||
ServerHttpRequest request; |
||||
|
||||
MockHttpServletRequest httpRequest; |
||||
|
||||
CsrfTokenHandshakeInterceptor interceptor; |
||||
|
||||
@Before |
||||
public void setup() { |
||||
httpRequest = new MockHttpServletRequest(); |
||||
attributes = new HashMap<String,Object>(); |
||||
request = new ServletServerHttpRequest(httpRequest); |
||||
|
||||
interceptor = new CsrfTokenHandshakeInterceptor(); |
||||
} |
||||
|
||||
@Test |
||||
public void beforeHandshakeNoAttribute() throws Exception { |
||||
interceptor.beforeHandshake(request, response, wsHandler, attributes); |
||||
|
||||
assertThat(attributes).isEmpty(); |
||||
} |
||||
|
||||
@Test |
||||
public void beforeHandshake() throws Exception { |
||||
CsrfToken token = new DefaultCsrfToken("header", "param", "token"); |
||||
httpRequest.setAttribute(CsrfToken.class.getName(), token); |
||||
|
||||
interceptor.beforeHandshake(request, response, wsHandler, attributes); |
||||
|
||||
assertThat(attributes.keySet()).containsOnly(CsrfToken.class.getName()); |
||||
assertThat(attributes.values()).containsOnly(token); |
||||
} |
||||
|
||||
} |
||||
@ -0,0 +1,33 @@
@@ -0,0 +1,33 @@
|
||||
/* |
||||
* Copyright 2002-2015 the original author or authors. |
||||
* |
||||
* Licensed under the Apache License, Version 2.0 (the "License"); |
||||
* you may not use this file except in compliance with the License. |
||||
* You may obtain a copy of the License at |
||||
* |
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* |
||||
* Unless required by applicable law or agreed to in writing, software |
||||
* distributed under the License is distributed on an "AS IS" BASIS, |
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
* See the License for the specific language governing permissions and |
||||
* limitations under the License. |
||||
*/ |
||||
package org.springframework.security.web.csrf; |
||||
|
||||
import org.junit.Test; |
||||
|
||||
/** |
||||
* |
||||
* @author Rob Winch |
||||
* |
||||
*/ |
||||
public class MissingCsrfTokenExceptionTests { |
||||
|
||||
// CsrfChannelInterceptor requires this to work
|
||||
@Test |
||||
public void nullExpectedTokenDoesNotFail() { |
||||
new MissingCsrfTokenException(null); |
||||
} |
||||
|
||||
} |
||||
Loading…
Reference in new issue