Browse Source

Polish spring-security-messaging main code

Manually polish `spring-security-messaging` following the formatting
and checkstyle fixes.

Issue gh-8945
pull/8983/head
Phillip Webb 6 years ago committed by Rob Winch
parent
commit
ad1dbf425f
  1. 3
      messaging/src/main/java/org/springframework/security/messaging/access/expression/EvaluationContextPostProcessor.java
  2. 8
      messaging/src/main/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactory.java
  3. 2
      messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttribute.java
  4. 10
      messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionVoter.java
  5. 2
      messaging/src/main/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSource.java
  6. 8
      messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java
  7. 41
      messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java
  8. 9
      messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java
  9. 9
      messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java
  10. 12
      messaging/src/main/java/org/springframework/security/messaging/util/matcher/AbstractMessageMatcherComposite.java
  11. 9
      messaging/src/main/java/org/springframework/security/messaging/util/matcher/AndMessageMatcher.java
  12. 16
      messaging/src/main/java/org/springframework/security/messaging/util/matcher/MessageMatcher.java
  13. 9
      messaging/src/main/java/org/springframework/security/messaging/util/matcher/OrMessageMatcher.java
  14. 13
      messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcher.java
  15. 2
      messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpMessageTypeMatcher.java
  16. 10
      messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java

3
messaging/src/main/java/org/springframework/security/messaging/access/expression/EvaluationContextPostProcessor.java

@ -19,8 +19,7 @@ package org.springframework.security.messaging.access.expression;
import org.springframework.expression.EvaluationContext; import org.springframework.expression.EvaluationContext;
/** /**
* * Allows post processing the {@link EvaluationContext}
* /** Allows post processing the {@link EvaluationContext}
* *
* <p> * <p>
* This API is intentionally kept package scope as it may evolve over time. * This API is intentionally kept package scope as it may evolve over time.

8
messaging/src/main/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactory.java

@ -38,6 +38,9 @@ import org.springframework.security.messaging.util.matcher.MessageMatcher;
*/ */
public final class ExpressionBasedMessageSecurityMetadataSourceFactory { public final class ExpressionBasedMessageSecurityMetadataSourceFactory {
private ExpressionBasedMessageSecurityMetadataSourceFactory() {
}
/** /**
* Create a {@link MessageSecurityMetadataSource} that uses {@link MessageMatcher} * Create a {@link MessageSecurityMetadataSource} that uses {@link MessageMatcher}
* mapped to Spring Expressions. Each entry is considered in order and only the first * mapped to Spring Expressions. Each entry is considered in order and only the first
@ -108,9 +111,7 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory {
public static MessageSecurityMetadataSource createExpressionMessageMetadataSource( public static MessageSecurityMetadataSource createExpressionMessageMetadataSource(
LinkedHashMap<MessageMatcher<?>, String> matcherToExpression, LinkedHashMap<MessageMatcher<?>, String> matcherToExpression,
SecurityExpressionHandler<Message<Object>> handler) { SecurityExpressionHandler<Message<Object>> handler) {
LinkedHashMap<MessageMatcher<?>, Collection<ConfigAttribute>> matcherToAttrs = new LinkedHashMap<>(); LinkedHashMap<MessageMatcher<?>, Collection<ConfigAttribute>> matcherToAttrs = new LinkedHashMap<>();
for (Map.Entry<MessageMatcher<?>, String> entry : matcherToExpression.entrySet()) { for (Map.Entry<MessageMatcher<?>, String> entry : matcherToExpression.entrySet()) {
MessageMatcher<?> matcher = entry.getKey(); MessageMatcher<?> matcher = entry.getKey();
String rawExpression = entry.getValue(); String rawExpression = entry.getValue();
@ -121,7 +122,4 @@ public final class ExpressionBasedMessageSecurityMetadataSourceFactory {
return new DefaultMessageSecurityMetadataSource(matcherToAttrs); return new DefaultMessageSecurityMetadataSource(matcherToAttrs);
} }
private ExpressionBasedMessageSecurityMetadataSourceFactory() {
}
} }

2
messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionConfigAttribute.java

@ -69,7 +69,7 @@ class MessageExpressionConfigAttribute implements ConfigAttribute, EvaluationCon
@Override @Override
public EvaluationContext postProcess(EvaluationContext ctx, Message<?> message) { public EvaluationContext postProcess(EvaluationContext ctx, Message<?> message) {
if (this.matcher instanceof SimpDestinationMessageMatcher) { if (this.matcher instanceof SimpDestinationMessageMatcher) {
final Map<String, String> variables = ((SimpDestinationMessageMatcher) this.matcher) Map<String, String> variables = ((SimpDestinationMessageMatcher) this.matcher)
.extractPathVariables(message); .extractPathVariables(message);
for (Map.Entry<String, String> entry : variables.entrySet()) { for (Map.Entry<String, String> entry : variables.entrySet()) {
ctx.setVariable(entry.getKey(), entry.getValue()); ctx.setVariable(entry.getKey(), entry.getValue());

10
messaging/src/main/java/org/springframework/security/messaging/access/expression/MessageExpressionVoter.java

@ -44,19 +44,15 @@ public class MessageExpressionVoter<T> implements AccessDecisionVoter<Message<T>
@Override @Override
public int vote(Authentication authentication, Message<T> message, Collection<ConfigAttribute> attributes) { public int vote(Authentication authentication, Message<T> message, Collection<ConfigAttribute> attributes) {
assert authentication != null; Assert.notNull(authentication, "authentication must not be null");
assert message != null; Assert.notNull(message, "message must not be null");
assert attributes != null; Assert.notNull(attributes, "attributes must not be null");
MessageExpressionConfigAttribute attr = findConfigAttribute(attributes); MessageExpressionConfigAttribute attr = findConfigAttribute(attributes);
if (attr == null) { if (attr == null) {
return ACCESS_ABSTAIN; return ACCESS_ABSTAIN;
} }
EvaluationContext ctx = this.expressionHandler.createEvaluationContext(authentication, message); EvaluationContext ctx = this.expressionHandler.createEvaluationContext(authentication, message);
ctx = attr.postProcess(ctx, message); ctx = attr.postProcess(ctx, message);
return ExpressionUtils.evaluateAsBoolean(attr.getAuthorizeExpression(), ctx) ? ACCESS_GRANTED : ACCESS_DENIED; return ExpressionUtils.evaluateAsBoolean(attr.getAuthorizeExpression(), ctx) ? ACCESS_GRANTED : ACCESS_DENIED;
} }

2
messaging/src/main/java/org/springframework/security/messaging/access/intercept/DefaultMessageSecurityMetadataSource.java

@ -65,11 +65,9 @@ public final class DefaultMessageSecurityMetadataSource implements MessageSecuri
@Override @Override
public Collection<ConfigAttribute> getAllConfigAttributes() { public Collection<ConfigAttribute> getAllConfigAttributes() {
Set<ConfigAttribute> allAttributes = new HashSet<>(); Set<ConfigAttribute> allAttributes = new HashSet<>();
for (Collection<ConfigAttribute> entry : this.messageMap.values()) { for (Collection<ConfigAttribute> entry : this.messageMap.values()) {
allAttributes.addAll(entry); allAttributes.addAll(entry);
} }
return allAttributes; return allAttributes;
} }

8
messaging/src/main/java/org/springframework/security/messaging/context/AuthenticationPrincipalArgumentResolver.java

@ -98,26 +98,20 @@ public final class AuthenticationPrincipalArgumentResolver implements HandlerMet
return null; return null;
} }
Object principal = authentication.getPrincipal(); Object principal = authentication.getPrincipal();
AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter); AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter);
String expressionToParse = authPrincipal.expression(); String expressionToParse = authPrincipal.expression();
if (StringUtils.hasLength(expressionToParse)) { if (StringUtils.hasLength(expressionToParse)) {
StandardEvaluationContext context = new StandardEvaluationContext(); StandardEvaluationContext context = new StandardEvaluationContext();
context.setRootObject(principal); context.setRootObject(principal);
context.setVariable("this", principal); context.setVariable("this", principal);
Expression expression = this.parser.parseExpression(expressionToParse); Expression expression = this.parser.parseExpression(expressionToParse);
principal = expression.getValue(context); principal = expression.getValue(context);
} }
if (principal != null && !parameter.getParameterType().isAssignableFrom(principal.getClass())) { if (principal != null && !parameter.getParameterType().isAssignableFrom(principal.getClass())) {
if (authPrincipal.errorOnInvalidType()) { if (authPrincipal.errorOnInvalidType()) {
throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType()); throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType());
} }
else { return null;
return null;
}
} }
return principal; return principal;
} }

41
messaging/src/main/java/org/springframework/security/messaging/context/SecurityContextChannelInterceptor.java

@ -43,9 +43,9 @@ import org.springframework.util.Assert;
public final class SecurityContextChannelInterceptor extends ChannelInterceptorAdapter public final class SecurityContextChannelInterceptor extends ChannelInterceptorAdapter
implements ExecutorChannelInterceptor { implements ExecutorChannelInterceptor {
private final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext(); private static final SecurityContext EMPTY_CONTEXT = SecurityContextHolder.createEmptyContext();
private static final ThreadLocal<Stack<SecurityContext>> ORIGINAL_CONTEXT = new ThreadLocal<>(); private static final ThreadLocal<Stack<SecurityContext>> originalContext = new ThreadLocal<>();
private final String authenticationHeaderName; private final String authenticationHeaderName;
@ -110,46 +110,41 @@ public final class SecurityContextChannelInterceptor extends ChannelInterceptorA
private void setup(Message<?> message) { private void setup(Message<?> message) {
SecurityContext currentContext = SecurityContextHolder.getContext(); SecurityContext currentContext = SecurityContextHolder.getContext();
Stack<SecurityContext> contextStack = originalContext.get();
Stack<SecurityContext> contextStack = ORIGINAL_CONTEXT.get();
if (contextStack == null) { if (contextStack == null) {
contextStack = new Stack<>(); contextStack = new Stack<>();
ORIGINAL_CONTEXT.set(contextStack); originalContext.set(contextStack);
} }
contextStack.push(currentContext); contextStack.push(currentContext);
Object user = message.getHeaders().get(this.authenticationHeaderName); Object user = message.getHeaders().get(this.authenticationHeaderName);
Authentication authentication = getAuthentication(user);
Authentication authentication;
if ((user instanceof Authentication)) {
authentication = (Authentication) user;
}
else {
authentication = this.anonymous;
}
SecurityContext context = SecurityContextHolder.createEmptyContext(); SecurityContext context = SecurityContextHolder.createEmptyContext();
context.setAuthentication(authentication); context.setAuthentication(authentication);
SecurityContextHolder.setContext(context); SecurityContextHolder.setContext(context);
} }
private void cleanup() { private Authentication getAuthentication(Object user) {
Stack<SecurityContext> contextStack = ORIGINAL_CONTEXT.get(); if ((user instanceof Authentication)) {
return (Authentication) user;
}
return this.anonymous;
}
private void cleanup() {
Stack<SecurityContext> contextStack = originalContext.get();
if (contextStack == null || contextStack.isEmpty()) { if (contextStack == null || contextStack.isEmpty()) {
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
ORIGINAL_CONTEXT.remove(); originalContext.remove();
return; return;
} }
SecurityContext context = contextStack.pop();
SecurityContext originalContext = contextStack.pop();
try { try {
if (this.EMPTY_CONTEXT.equals(originalContext)) { if (SecurityContextChannelInterceptor.EMPTY_CONTEXT.equals(context)) {
SecurityContextHolder.clearContext(); SecurityContextHolder.clearContext();
ORIGINAL_CONTEXT.remove(); originalContext.remove();
} }
else { else {
SecurityContextHolder.setContext(originalContext); SecurityContextHolder.setContext(context);
} }
} }
catch (Throwable ex) { catch (Throwable ex) {

9
messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/AuthenticationPrincipalArgumentResolver.java

@ -134,28 +134,21 @@ public class AuthenticationPrincipalArgumentResolver implements HandlerMethodArg
private Object resolvePrincipal(MethodParameter parameter, Object principal) { private Object resolvePrincipal(MethodParameter parameter, Object principal) {
AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter); AuthenticationPrincipal authPrincipal = findMethodAnnotation(AuthenticationPrincipal.class, parameter);
String expressionToParse = authPrincipal.expression(); String expressionToParse = authPrincipal.expression();
if (StringUtils.hasLength(expressionToParse)) { if (StringUtils.hasLength(expressionToParse)) {
StandardEvaluationContext context = new StandardEvaluationContext(); StandardEvaluationContext context = new StandardEvaluationContext();
context.setRootObject(principal); context.setRootObject(principal);
context.setVariable("this", principal); context.setVariable("this", principal);
context.setBeanResolver(this.beanResolver); context.setBeanResolver(this.beanResolver);
Expression expression = this.parser.parseExpression(expressionToParse); Expression expression = this.parser.parseExpression(expressionToParse);
principal = expression.getValue(context); principal = expression.getValue(context);
} }
if (isInvalidType(parameter, principal)) { if (isInvalidType(parameter, principal)) {
if (authPrincipal.errorOnInvalidType()) { if (authPrincipal.errorOnInvalidType()) {
throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType()); throw new ClassCastException(principal + " is not assignable to " + parameter.getParameterType());
} }
else { return null;
return null;
}
} }
return principal; return principal;
} }

9
messaging/src/main/java/org/springframework/security/messaging/handler/invocation/reactive/CurrentSecurityContextArgumentResolver.java

@ -133,28 +133,21 @@ public class CurrentSecurityContextArgumentResolver implements HandlerMethodArgu
private Object resolveSecurityContext(MethodParameter parameter, Object securityContext) { private Object resolveSecurityContext(MethodParameter parameter, Object securityContext) {
CurrentSecurityContext contextAnno = findMethodAnnotation(CurrentSecurityContext.class, parameter); CurrentSecurityContext contextAnno = findMethodAnnotation(CurrentSecurityContext.class, parameter);
String expressionToParse = contextAnno.expression(); String expressionToParse = contextAnno.expression();
if (StringUtils.hasLength(expressionToParse)) { if (StringUtils.hasLength(expressionToParse)) {
StandardEvaluationContext context = new StandardEvaluationContext(); StandardEvaluationContext context = new StandardEvaluationContext();
context.setRootObject(securityContext); context.setRootObject(securityContext);
context.setVariable("this", securityContext); context.setVariable("this", securityContext);
context.setBeanResolver(this.beanResolver); context.setBeanResolver(this.beanResolver);
Expression expression = this.parser.parseExpression(expressionToParse); Expression expression = this.parser.parseExpression(expressionToParse);
securityContext = expression.getValue(context); securityContext = expression.getValue(context);
} }
if (isInvalidType(parameter, securityContext)) { if (isInvalidType(parameter, securityContext)) {
if (contextAnno.errorOnInvalidType()) { if (contextAnno.errorOnInvalidType()) {
throw new ClassCastException(securityContext + " is not assignable to " + parameter.getParameterType()); throw new ClassCastException(securityContext + " is not assignable to " + parameter.getParameterType());
} }
else { return null;
return null;
}
} }
return securityContext; return securityContext;
} }

12
messaging/src/main/java/org/springframework/security/messaging/util/matcher/AbstractMessageMatcherComposite.java

@ -31,7 +31,13 @@ import org.springframework.util.Assert;
*/ */
public abstract class AbstractMessageMatcherComposite<T> implements MessageMatcher<T> { public abstract class AbstractMessageMatcherComposite<T> implements MessageMatcher<T> {
protected final Log LOGGER = LogFactory.getLog(getClass()); protected final Log logger = LogFactory.getLog(getClass());
/**
* @deprecated since 5.4 in favor of {@link #logger}
*/
@Deprecated
protected final Log LOGGER = this.logger;
private final List<MessageMatcher<T>> messageMatchers; private final List<MessageMatcher<T>> messageMatchers;
@ -41,9 +47,7 @@ public abstract class AbstractMessageMatcherComposite<T> implements MessageMatch
*/ */
AbstractMessageMatcherComposite(List<MessageMatcher<T>> messageMatchers) { AbstractMessageMatcherComposite(List<MessageMatcher<T>> messageMatchers) {
Assert.notEmpty(messageMatchers, "messageMatchers must contain a value"); Assert.notEmpty(messageMatchers, "messageMatchers must contain a value");
if (messageMatchers.contains(null)) { Assert.isTrue(!messageMatchers.contains(null), "messageMatchers cannot contain null values");
throw new IllegalArgumentException("messageMatchers cannot contain null values");
}
this.messageMatchers = messageMatchers; this.messageMatchers = messageMatchers;
} }

9
messaging/src/main/java/org/springframework/security/messaging/util/matcher/AndMessageMatcher.java

@ -18,6 +18,7 @@ package org.springframework.security.messaging.util.matcher;
import java.util.List; import java.util.List;
import org.springframework.core.log.LogMessage;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
/** /**
@ -49,15 +50,13 @@ public final class AndMessageMatcher<T> extends AbstractMessageMatcherComposite<
@Override @Override
public boolean matches(Message<? extends T> message) { public boolean matches(Message<? extends T> message) {
for (MessageMatcher<T> matcher : getMessageMatchers()) { for (MessageMatcher<T> matcher : getMessageMatchers()) {
if (this.LOGGER.isDebugEnabled()) { this.logger.debug(LogMessage.format("Trying to match using %s", matcher));
this.LOGGER.debug("Trying to match using " + matcher);
}
if (!matcher.matches(message)) { if (!matcher.matches(message)) {
this.LOGGER.debug("Did not match"); this.logger.debug("Did not match");
return false; return false;
} }
} }
this.LOGGER.debug("All messageMatchers returned true"); this.logger.debug("All messageMatchers returned true");
return true; return true;
} }

16
messaging/src/main/java/org/springframework/security/messaging/util/matcher/MessageMatcher.java

@ -26,17 +26,11 @@ import org.springframework.messaging.Message;
*/ */
public interface MessageMatcher<T> { public interface MessageMatcher<T> {
/**
* Returns true if the {@link Message} matches, else false
* @param message the {@link Message} to match on
* @return true if the {@link Message} matches, else false
*/
boolean matches(Message<? extends T> message);
/** /**
* Matches every {@link Message} * Matches every {@link Message}
*/ */
MessageMatcher<Object> ANY_MESSAGE = new MessageMatcher<Object>() { MessageMatcher<Object> ANY_MESSAGE = new MessageMatcher<Object>() {
@Override @Override
public boolean matches(Message<?> message) { public boolean matches(Message<?> message) {
return true; return true;
@ -46,6 +40,14 @@ public interface MessageMatcher<T> {
public String toString() { public String toString() {
return "ANY_MESSAGE"; return "ANY_MESSAGE";
} }
}; };
/**
* Returns true if the {@link Message} matches, else false
* @param message the {@link Message} to match on
* @return true if the {@link Message} matches, else false
*/
boolean matches(Message<? extends T> message);
} }

9
messaging/src/main/java/org/springframework/security/messaging/util/matcher/OrMessageMatcher.java

@ -18,6 +18,7 @@ package org.springframework.security.messaging.util.matcher;
import java.util.List; import java.util.List;
import org.springframework.core.log.LogMessage;
import org.springframework.messaging.Message; import org.springframework.messaging.Message;
/** /**
@ -49,15 +50,13 @@ public final class OrMessageMatcher<T> extends AbstractMessageMatcherComposite<T
@Override @Override
public boolean matches(Message<? extends T> message) { public boolean matches(Message<? extends T> message) {
for (MessageMatcher<T> matcher : getMessageMatchers()) { for (MessageMatcher<T> matcher : getMessageMatchers()) {
if (this.LOGGER.isDebugEnabled()) { this.logger.debug(LogMessage.format("Trying to match using %s", matcher));
this.LOGGER.debug("Trying to match using " + matcher);
}
if (matcher.matches(message)) { if (matcher.matches(message)) {
this.LOGGER.debug("matched"); this.logger.debug("matched");
return true; return true;
} }
} }
this.LOGGER.debug("No matches found"); this.logger.debug("No matches found");
return false; return false;
} }

13
messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpDestinationMessageMatcher.java

@ -107,11 +107,8 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher<Objec
private SimpDestinationMessageMatcher(String pattern, SimpMessageType type, PathMatcher pathMatcher) { private SimpDestinationMessageMatcher(String pattern, SimpMessageType type, PathMatcher pathMatcher) {
Assert.notNull(pattern, "pattern cannot be null"); Assert.notNull(pattern, "pattern cannot be null");
Assert.notNull(pathMatcher, "pathMatcher cannot be null"); Assert.notNull(pathMatcher, "pathMatcher cannot be null");
if (!isTypeWithDestination(type)) { Assert.isTrue(isTypeWithDestination(type),
throw new IllegalArgumentException( () -> "SimpMessageType " + type + " does not contain a destination and so cannot be matched on.");
"SimpMessageType " + type + " does not contain a destination and so cannot be matched on.");
}
this.matcher = pathMatcher; this.matcher = pathMatcher;
this.messageTypeMatcher = (type != null) ? new SimpMessageTypeMatcher(type) : ANY_MESSAGE; this.messageTypeMatcher = (type != null) ? new SimpMessageTypeMatcher(type) : ANY_MESSAGE;
this.pattern = pattern; this.pattern = pattern;
@ -122,7 +119,6 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher<Objec
if (!this.messageTypeMatcher.matches(message)) { if (!this.messageTypeMatcher.matches(message)) {
return false; return false;
} }
String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders()); String destination = SimpMessageHeaderAccessor.getDestination(message.getHeaders());
return destination != null && this.matcher.match(this.pattern, destination); return destination != null && this.matcher.match(this.pattern, destination);
} }
@ -144,10 +140,7 @@ public final class SimpDestinationMessageMatcher implements MessageMatcher<Objec
} }
private boolean isTypeWithDestination(SimpMessageType type) { private boolean isTypeWithDestination(SimpMessageType type) {
if (type == null) { return type == null || SimpMessageType.MESSAGE.equals(type) || SimpMessageType.SUBSCRIBE.equals(type);
return true;
}
return SimpMessageType.MESSAGE.equals(type) || SimpMessageType.SUBSCRIBE.equals(type);
} }
/** /**

2
messaging/src/main/java/org/springframework/security/messaging/util/matcher/SimpMessageTypeMatcher.java

@ -49,7 +49,6 @@ public class SimpMessageTypeMatcher implements MessageMatcher<Object> {
public boolean matches(Message<?> message) { public boolean matches(Message<?> message) {
MessageHeaders headers = message.getHeaders(); MessageHeaders headers = message.getHeaders();
SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers); SimpMessageType messageType = SimpMessageHeaderAccessor.getMessageType(headers);
return this.typeToMatch == messageType; return this.typeToMatch == messageType;
} }
@ -63,7 +62,6 @@ public class SimpMessageTypeMatcher implements MessageMatcher<Object> {
} }
SimpMessageTypeMatcher otherMatcher = (SimpMessageTypeMatcher) other; SimpMessageTypeMatcher otherMatcher = (SimpMessageTypeMatcher) other;
return ObjectUtils.nullSafeEquals(this.typeToMatch, otherMatcher.typeToMatch); return ObjectUtils.nullSafeEquals(this.typeToMatch, otherMatcher.typeToMatch);
} }
@Override @Override

10
messaging/src/main/java/org/springframework/security/messaging/web/csrf/CsrfChannelInterceptor.java

@ -46,23 +46,19 @@ public final class CsrfChannelInterceptor extends ChannelInterceptorAdapter {
if (!this.matcher.matches(message)) { if (!this.matcher.matches(message)) {
return message; return message;
} }
Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders()); Map<String, Object> sessionAttributes = SimpMessageHeaderAccessor.getSessionAttributes(message.getHeaders());
CsrfToken expectedToken = (sessionAttributes != null) CsrfToken expectedToken = (sessionAttributes != null)
? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null; ? (CsrfToken) sessionAttributes.get(CsrfToken.class.getName()) : null;
if (expectedToken == null) { if (expectedToken == null) {
throw new MissingCsrfTokenException(null); throw new MissingCsrfTokenException(null);
} }
String actualTokenValue = SimpMessageHeaderAccessor.wrap(message) String actualTokenValue = SimpMessageHeaderAccessor.wrap(message)
.getFirstNativeHeader(expectedToken.getHeaderName()); .getFirstNativeHeader(expectedToken.getHeaderName());
boolean csrfCheckPassed = expectedToken.getToken().equals(actualTokenValue); boolean csrfCheckPassed = expectedToken.getToken().equals(actualTokenValue);
if (csrfCheckPassed) { if (!csrfCheckPassed) {
return message; throw new InvalidCsrfTokenException(expectedToken, actualTokenValue);
} }
throw new InvalidCsrfTokenException(expectedToken, actualTokenValue); return message;
} }
} }

Loading…
Cancel
Save