diff --git a/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java b/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java index 9c96ff4300..75430a4171 100644 --- a/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java +++ b/config/src/main/java/org/springframework/security/config/annotation/web/socket/AbstractSecurityWebSocketMessageBrokerConfigurer.java @@ -22,18 +22,14 @@ import org.springframework.messaging.simp.config.ChannelRegistration; import org.springframework.security.access.AccessDecisionVoter; import org.springframework.security.access.vote.AffirmativeBased; import org.springframework.security.config.annotation.web.messaging.MessageSecurityMetadataSourceRegistry; -import org.springframework.security.messaging.access.expression.ExpressionBasedMessageSecurityMetadataSourceFactory; import org.springframework.security.messaging.access.expression.MessageExpressionVoter; import org.springframework.security.messaging.access.intercept.ChannelSecurityInterceptor; import org.springframework.security.messaging.access.intercept.MessageSecurityMetadataSource; import org.springframework.security.messaging.context.SecurityContextChannelInterceptor; -import org.springframework.security.messaging.util.matcher.MessageMatcher; -import org.springframework.security.messaging.util.matcher.SimpDestinationMessageMatcher; import org.springframework.web.socket.config.annotation.AbstractWebSocketMessageBrokerConfigurer; import org.springframework.web.socket.config.annotation.StompEndpointRegistry; import java.util.ArrayList; -import java.util.LinkedHashMap; import java.util.List; /** @@ -46,7 +42,7 @@ import java.util.List; * public class WebSocketSecurityConfig extends AbstractSecurityWebSocketMessageBrokerConfigurer { * * @Override - * protected void configure(MessageSecurityMetadataSourceRegistry messages) { + * protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) { * messages * .antMatchers("/user/queue/errors").permitAll() * .antMatchers("/admin/**").hasRole("ADMIN") @@ -61,22 +57,40 @@ import java.util.List; */ @Order(Ordered.HIGHEST_PRECEDENCE + 100) public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends AbstractWebSocketMessageBrokerConfigurer { + private final WebSocketMessageSecurityMetadataSourceRegistry inboundRegistry = new WebSocketMessageSecurityMetadataSourceRegistry(); + private final WebSocketMessageSecurityMetadataSourceRegistry outboundRegistry = new WebSocketMessageSecurityMetadataSourceRegistry(); - public void registerStompEndpoints(StompEndpointRegistry registry) {} + public final void registerStompEndpoints(StompEndpointRegistry registry) {} @Override - public void configureClientInboundChannel(ChannelRegistration registration) { - registration.setInterceptors(securityContextChannelInterceptor(),channelSecurity()); + public final void configureClientInboundChannel(ChannelRegistration registration) { + ChannelSecurityInterceptor inboundChannelSecurity = inboundChannelSecurity(); + if(inboundRegistry.containsMapping()) { + registration.setInterceptors(securityContextChannelInterceptor(),inboundChannelSecurity); + } } @Override - public void configureClientOutboundChannel(ChannelRegistration registration) { - registration.setInterceptors(securityContextChannelInterceptor(),channelSecurity()); + public final void configureClientOutboundChannel(ChannelRegistration registration) { + ChannelSecurityInterceptor outboundChannelSecurity = outboundChannelSecurity(); + if(outboundRegistry.containsMapping()) { + registration.setInterceptors(securityContextChannelInterceptor(),outboundChannelSecurity); + } } @Bean - public ChannelSecurityInterceptor channelSecurity() { - ChannelSecurityInterceptor channelSecurityInterceptor = new ChannelSecurityInterceptor(metadataSource()); + public ChannelSecurityInterceptor inboundChannelSecurity() { + ChannelSecurityInterceptor channelSecurityInterceptor = new ChannelSecurityInterceptor(inboundMessageSecurityMetadataSource()); + List voters = new ArrayList(); + voters.add(new MessageExpressionVoter()); + AffirmativeBased manager = new AffirmativeBased(voters); + channelSecurityInterceptor.setAccessDecisionManager(manager); + return channelSecurityInterceptor; + } + + @Bean + public ChannelSecurityInterceptor outboundChannelSecurity() { + ChannelSecurityInterceptor channelSecurityInterceptor = new ChannelSecurityInterceptor(outboundMessageSecurityMetadataSource()); List voters = new ArrayList(); voters.add(new MessageExpressionVoter()); AffirmativeBased manager = new AffirmativeBased(voters); @@ -90,22 +104,38 @@ public abstract class AbstractSecurityWebSocketMessageBrokerConfigurer extends A } @Bean - public MessageSecurityMetadataSource metadataSource() { - WebSocketMessageSecurityMetadataSourceRegistry registry = new WebSocketMessageSecurityMetadataSourceRegistry(); - configure(registry); - return registry.createMetadataSource(); + public MessageSecurityMetadataSource inboundMessageSecurityMetadataSource() { + configureInbound(inboundRegistry); + return inboundRegistry.createMetadataSource(); } + @Bean + public MessageSecurityMetadataSource outboundMessageSecurityMetadataSource() { + configureOutbound(outboundRegistry); + return outboundRegistry.createMetadataSource(); + } + + /** + * + * @param messages + */ + protected void configureInbound(MessageSecurityMetadataSourceRegistry messages) {} + /** * * @param messages */ - protected abstract void configure(MessageSecurityMetadataSourceRegistry messages); + protected void configureOutbound(MessageSecurityMetadataSourceRegistry messages) {} private class WebSocketMessageSecurityMetadataSourceRegistry extends MessageSecurityMetadataSourceRegistry { @Override public MessageSecurityMetadataSource createMetadataSource() { return super.createMetadataSource(); } + + @Override + protected boolean containsMapping() { + return super.containsMapping(); + } } } diff --git a/messaging/src/test/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactoryTests.java b/messaging/src/test/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactoryTests.java index aacdaa56dc..f6d3a9c78b 100644 --- a/messaging/src/test/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactoryTests.java +++ b/messaging/src/test/java/org/springframework/security/messaging/access/expression/ExpressionBasedMessageSecurityMetadataSourceFactoryTests.java @@ -16,9 +16,6 @@ package org.springframework.security.messaging.access.expression; import static org.fest.assertions.Assertions.assertThat; -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import static org.powermock.api.mockito.PowerMockito.when; import static org.springframework.security.messaging.access.expression.ExpressionBasedMessageSecurityMetadataSourceFactory.*; @@ -39,11 +36,11 @@ import java.util.LinkedHashMap; @RunWith(MockitoJUnitRunner.class) public class ExpressionBasedMessageSecurityMetadataSourceFactoryTests { @Mock - MessageMatcher matcher1; + MessageMatcher matcher1; @Mock - MessageMatcher matcher2; + MessageMatcher matcher2; @Mock - Message message; + Message message; @Mock Authentication authentication;