diff --git a/web/src/test/groovy/org/springframework/security/web/context/AbstractSecurityWebApplicationInitializerTests.groovy b/web/src/test/groovy/org/springframework/security/web/context/AbstractSecurityWebApplicationInitializerTests.groovy deleted file mode 100644 index 4dd55fbe95..0000000000 --- a/web/src/test/groovy/org/springframework/security/web/context/AbstractSecurityWebApplicationInitializerTests.groovy +++ /dev/null @@ -1,308 +0,0 @@ -/* - * Copyright 2002-2013 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.springframework.security.web.context; - -import javax.servlet.DispatcherType -import javax.servlet.Filter -import javax.servlet.FilterRegistration -import javax.servlet.ServletContext -import javax.servlet.SessionTrackingMode - -import org.springframework.context.annotation.Configuration -import org.springframework.security.web.session.HttpSessionEventPublisher -import org.springframework.web.context.ContextLoaderListener -import org.springframework.web.filter.DelegatingFilterProxy - -import spock.lang.Specification - -/** - * @author Rob Winch - * - */ -class AbstractSecurityWebApplicationInitializerTests extends Specification { - def DEFAULT_DISPATCH = EnumSet.of(DispatcherType.REQUEST, DispatcherType.ERROR, DispatcherType.ASYNC) - - def defaults() { - setup: - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){}.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == null}) >> registration - 1 * registration.addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); - 1 * registration.setAsyncSupported(true) - 0 * context.addListener(_) - } - - def "defaults with ContextLoaderListener"() { - setup: - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(MyRootConfiguration){}.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == null}) >> registration - 1 * registration.addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); - 1 * registration.setAsyncSupported(true) - 1 * context.addListener(_ as ContextLoaderListener) - } - - @Configuration - static class MyRootConfiguration {} - - def "enableHttpSessionEventPublisher() = true"() { - setup: - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){ - protected boolean enableHttpSessionEventPublisher() { - return true; - } - }.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == null}) >> registration - 1 * registration.addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); - 1 * registration.setAsyncSupported(true) - 1 * context.addListener(HttpSessionEventPublisher.class.name) - } - - def "custom getSecurityDispatcherTypes()"() { - setup: - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){ - protected EnumSet getSecurityDispatcherTypes() { - return EnumSet.of(DispatcherType.REQUEST, DispatcherType.ERROR, DispatcherType.FORWARD); - } - }.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == null}) >> registration - 1 * registration.addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST, DispatcherType.ERROR, DispatcherType.FORWARD), false, "/*"); - 1 * registration.setAsyncSupported(true) - 0 * context.addListener(_) - } - - def "custom getDispatcherWebApplicationContextSuffix"() { - setup: - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){ - protected String getDispatcherWebApplicationContextSuffix() { - return "dispatcher" - } - }.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == "org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher"}) >> registration - 1 * registration.addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); - 1 * registration.setAsyncSupported(true) - 0 * context.addListener(_) - } - - def "springSecurityFilterChain already registered"() { - setup: - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){}.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == null}) >> null - IllegalStateException success = thrown() - success.message == "Duplicate Filter registration for 'springSecurityFilterChain'. Check to ensure the Filter is only configured once." - } - - def "insertFilters"() { - setup: - Filter filter1 = Mock() - Filter filter2 = Mock() - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){ - protected void afterSpringSecurityFilterChain(ServletContext servletContext) { - insertFilters(context, filter1, filter2); - } - }.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == null}) >> registration - 3 * registration.addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); - 3 * registration.setAsyncSupported(true) - 0 * context.addListener(_) - 1 * context.addFilter(_, filter1) >> registration - 1 * context.addFilter(_, filter2) >> registration - } - - def "insertFilters already registered"() { - setup: - Filter filter1 = Mock() - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){ - protected void afterSpringSecurityFilterChain(ServletContext servletContext) { - insertFilters(context, filter1); - } - }.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == null}) >> registration - 1 * registration.addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); - 1 * context.addFilter(_, filter1) >> null - IllegalStateException success = thrown() - success.message == "Duplicate Filter registration for 'filter'. Check to ensure the Filter is only configured once." - } - - def "insertFilters no filters"() { - setup: - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){ - protected void afterSpringSecurityFilterChain(ServletContext servletContext) { - insertFilters(context); - } - }.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == null}) >> registration - IllegalArgumentException success = thrown() - success.message == "filters cannot be null or empty" - } - - def "insertFilters filters with null"() { - setup: - Filter filter1 = Mock() - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){ - protected void afterSpringSecurityFilterChain(ServletContext servletContext) { - insertFilters(context, filter1, null); - } - }.onStartup(context) - then: - 2 * context.addFilter(_, _) >> registration - IllegalArgumentException success = thrown() - success.message == "filters cannot contain null values. Got [Mock for type 'Filter' named 'filter1', null]" - } - - def "appendFilters"() { - setup: - Filter filter1 = Mock() - Filter filter2 = Mock() - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){ - protected void afterSpringSecurityFilterChain(ServletContext servletContext) { - appendFilters(context,filter1, filter2); - } - }.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == null}) >> registration - 1 * registration.addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); - 2 * registration.addMappingForUrlPatterns(DEFAULT_DISPATCH, true, "/*"); - 3 * registration.setAsyncSupported(true) - 0 * context.addListener(_) - 1 * context.addFilter(_, filter1) >> registration - 1 * context.addFilter(_, filter2) >> registration - } - - def "appendFilters already registered"() { - setup: - Filter filter1 = Mock() - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){ - protected void afterSpringSecurityFilterChain(ServletContext servletContext) { - appendFilters(context, filter1); - } - }.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == null}) >> registration - 1 * registration.addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); - 1 * context.addFilter(_, filter1) >> null - IllegalStateException success = thrown() - success.message == "Duplicate Filter registration for 'filter'. Check to ensure the Filter is only configured once." - } - - def "appendFilters no filters"() { - setup: - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){ - protected void afterSpringSecurityFilterChain(ServletContext servletContext) { - appendFilters(context); - } - }.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == null}) >> registration - IllegalArgumentException success = thrown() - success.message == "filters cannot be null or empty" - } - - def "sessionTrackingModes defaults"() { - setup: - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){ }.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == null}) >> registration - 1 * context.setSessionTrackingModes({Set modes -> modes.size() == 1 && modes.containsAll([SessionTrackingMode.COOKIE]) }) - } - - def "sessionTrackingModes override"() { - setup: - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){ - @Override - public Set getSessionTrackingModes() { - return [SessionTrackingMode.SSL] - } - }.onStartup(context) - then: - 1 * context.addFilter("springSecurityFilterChain", {DelegatingFilterProxy f -> f.targetBeanName == "springSecurityFilterChain" && f.contextAttribute == null}) >> registration - 1 * context.setSessionTrackingModes({Set modes -> modes.size() == 1 && modes.containsAll([SessionTrackingMode.SSL]) }) - } - - def "appendFilters filters with null"() { - setup: - Filter filter1 = Mock() - ServletContext context = Mock() - FilterRegistration.Dynamic registration = Mock() - when: - new AbstractSecurityWebApplicationInitializer(){ - protected void afterSpringSecurityFilterChain(ServletContext servletContext) { - appendFilters(context, filter1, null); - } - }.onStartup(context) - then: - 2 * context.addFilter(_, _) >> registration - IllegalArgumentException success = thrown() - success.message == "filters cannot contain null values. Got [Mock for type 'Filter' named 'filter1', null]" - } - - def "DEFAULT_FILTER_NAME == springSecurityFilterChain"() { - expect: - AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME == "springSecurityFilterChain" - } -} diff --git a/web/src/test/java/org/springframework/security/web/context/AbstractSecurityWebApplicationInitializerTests.java b/web/src/test/java/org/springframework/security/web/context/AbstractSecurityWebApplicationInitializerTests.java new file mode 100644 index 0000000000..a96bf119c1 --- /dev/null +++ b/web/src/test/java/org/springframework/security/web/context/AbstractSecurityWebApplicationInitializerTests.java @@ -0,0 +1,447 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.security.web.context; + +import java.util.Collections; +import java.util.EnumSet; +import java.util.EventListener; +import java.util.HashSet; +import java.util.Set; +import javax.servlet.DispatcherType; +import javax.servlet.Filter; +import javax.servlet.FilterRegistration; +import javax.servlet.ServletContext; +import javax.servlet.SessionTrackingMode; + +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import org.springframework.context.annotation.Configuration; +import org.springframework.security.web.session.HttpSessionEventPublisher; +import org.springframework.web.context.ContextLoaderListener; +import org.springframework.web.filter.DelegatingFilterProxy; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatCode; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.powermock.api.mockito.PowerMockito.when; + +/** + * @author Rob Winch + * @author Josh Cummings + */ +public class AbstractSecurityWebApplicationInitializerTests { + private static final EnumSet DEFAULT_DISPATCH = + EnumSet.of(DispatcherType.REQUEST, DispatcherType.ERROR, DispatcherType.ASYNC); + + @Test + public void onStartupWhenDefaultContextThenRegistersSpringSecurityFilterChain() { + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + + new AbstractSecurityWebApplicationInitializer() {}.onStartup(context); + + assertProxyDefaults(proxyCaptor.getValue()); + + verify(registration).addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); + verify(registration).setAsyncSupported(true); + verifyNoAddListener(context); + } + + @Test + public void onStartupWhenConfigurationClassThenAddsContextLoaderListener() { + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + + new AbstractSecurityWebApplicationInitializer(MyRootConfiguration.class) {}.onStartup(context); + + assertProxyDefaults(proxyCaptor.getValue()); + + verify(registration).addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); + verify(registration).setAsyncSupported(true); + verify(context).addListener(any(ContextLoaderListener.class)); + } + + @Configuration + static class MyRootConfiguration {} + + @Test + public void onStartupWhenEnableHttpSessionEventPublisherIsTrueThenAddsHttpSessionEventPublisher() { + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + + new AbstractSecurityWebApplicationInitializer() { + protected boolean enableHttpSessionEventPublisher() { + return true; + } + }.onStartup(context); + + assertProxyDefaults(proxyCaptor.getValue()); + + verify(registration).addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); + verify(registration).setAsyncSupported(true); + verify(context).addListener(HttpSessionEventPublisher.class.getName()); + } + + @Test + public void onStartupWhenCustomSecurityDispatcherTypesThenUses() { + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + + new AbstractSecurityWebApplicationInitializer() { + protected EnumSet getSecurityDispatcherTypes() { + return EnumSet.of(DispatcherType.REQUEST, DispatcherType.ERROR, DispatcherType.FORWARD); + } + }.onStartup(context); + + assertProxyDefaults(proxyCaptor.getValue()); + + verify(registration).addMappingForUrlPatterns(EnumSet.of(DispatcherType.REQUEST, DispatcherType.ERROR, DispatcherType.FORWARD), false, "/*"); + verify(registration).setAsyncSupported(true); + verifyNoAddListener(context); + } + + @Test + public void onStartupWhenCustomDispatcherWebApplicationContextSuffixThenUses() { + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + + new AbstractSecurityWebApplicationInitializer() { + protected String getDispatcherWebApplicationContextSuffix() { + return "dispatcher"; + } + }.onStartup(context); + + DelegatingFilterProxy proxy = proxyCaptor.getValue(); + assertThat(proxy.getContextAttribute()) + .isEqualTo("org.springframework.web.servlet.FrameworkServlet.CONTEXT.dispatcher"); + assertThat(proxy).hasFieldOrPropertyWithValue("targetBeanName", "springSecurityFilterChain"); + + verify(registration).addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); + verify(registration).setAsyncSupported(true); + verifyNoAddListener(context); + } + + @Test + public void onStartupWhenSpringSecurityFilterChainAlreadyRegisteredThenException() { + ServletContext context = mock(ServletContext.class); + + assertThatCode(() -> + new AbstractSecurityWebApplicationInitializer() {}.onStartup(context)) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Duplicate Filter registration for 'springSecurityFilterChain'. " + + "Check to ensure the Filter is only configured once."); + } + + @Test + public void onStartupWhenInsertFiltersThenInserted() { + Filter filter1 = mock(Filter.class); + Filter filter2 = mock(Filter.class); + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + when(context.addFilter(anyString(), eq(filter1))).thenReturn(registration); + when(context.addFilter(anyString(), eq(filter2))).thenReturn(registration); + + new AbstractSecurityWebApplicationInitializer() { + protected void afterSpringSecurityFilterChain(ServletContext servletContext) { + insertFilters(context, filter1, filter2); + } + }.onStartup(context); + + assertProxyDefaults(proxyCaptor.getValue()); + + verify(registration, times(3)).addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); + verify(registration, times(3)).setAsyncSupported(true); + verifyNoAddListener(context); + verify(context).addFilter(anyString(), eq(filter1)); + verify(context).addFilter(anyString(), eq(filter2)); + } + + @Test + public void onStartupWhenDuplicateFilterInsertedThenException() { + Filter filter1 = mock(Filter.class); + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + + assertThatCode(() -> + new AbstractSecurityWebApplicationInitializer() { + protected void afterSpringSecurityFilterChain(ServletContext servletContext) { + insertFilters(context, filter1); + } + }.onStartup(context)) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Duplicate Filter registration for 'object'. " + + "Check to ensure the Filter is only configured once."); + + assertProxyDefaults(proxyCaptor.getValue()); + + verify(registration).addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); + verify(context).addFilter(anyString(), eq(filter1)); + } + + @Test + public void onStartupWhenInsertFiltersEmptyThenException() { + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + + assertThatCode(() -> + new AbstractSecurityWebApplicationInitializer() { + protected void afterSpringSecurityFilterChain(ServletContext servletContext) { + insertFilters(context); + } + }.onStartup(context)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("filters cannot be null or empty"); + + assertProxyDefaults(proxyCaptor.getValue()); + } + + @Test + public void onStartupWhenNullFilterInsertedThenException() { + Filter filter = mock(Filter.class); + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + when(context.addFilter(anyString(), eq(filter))).thenReturn(registration); + + assertThatCode(() -> + new AbstractSecurityWebApplicationInitializer() { + protected void afterSpringSecurityFilterChain(ServletContext servletContext) { + insertFilters(context, filter, null); + } + }.onStartup(context)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("filters cannot contain null values"); + + verify(context, times(2)).addFilter(anyString(), any(Filter.class)); + } + + @Test + public void onStartupWhenAppendFiltersThenAppended() { + Filter filter1 = mock(Filter.class); + Filter filter2 = mock(Filter.class); + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + when(context.addFilter(anyString(), eq(filter1))).thenReturn(registration); + when(context.addFilter(anyString(), eq(filter2))).thenReturn(registration); + + new AbstractSecurityWebApplicationInitializer() { + protected void afterSpringSecurityFilterChain(ServletContext servletContext) { + appendFilters(context, filter1, filter2); + } + }.onStartup(context); + + verify(registration, times(1)).addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); + verify(registration, times(2)).addMappingForUrlPatterns(DEFAULT_DISPATCH, true, "/*"); + verify(registration, times(3)).setAsyncSupported(true); + verifyNoAddListener(context); + verify(context, times(3)).addFilter(anyString(), any(Filter.class)); + } + + @Test + public void onStartupWhenDuplicateFilterAppendedThenException() { + Filter filter1 = mock(Filter.class); + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + + assertThatCode(() -> + new AbstractSecurityWebApplicationInitializer() { + protected void afterSpringSecurityFilterChain(ServletContext servletContext) { + appendFilters(context, filter1); + } + }.onStartup(context)) + .isInstanceOf(IllegalStateException.class) + .hasMessage("Duplicate Filter registration for 'object'. " + + "Check to ensure the Filter is only configured once."); + + assertProxyDefaults(proxyCaptor.getValue()); + + verify(registration).addMappingForUrlPatterns(DEFAULT_DISPATCH, false, "/*"); + verify(context).addFilter(anyString(), eq(filter1)); + } + + + @Test + public void onStartupWhenAppendFiltersEmptyThenException() { + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + + assertThatCode(() -> + new AbstractSecurityWebApplicationInitializer() { + protected void afterSpringSecurityFilterChain(ServletContext servletContext) { + appendFilters(context); + } + }.onStartup(context)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("filters cannot be null or empty"); + + assertProxyDefaults(proxyCaptor.getValue()); + } + + @Test + public void onStartupWhenNullFilterAppendedThenException() { + Filter filter = mock(Filter.class); + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + when(context.addFilter(anyString(), eq(filter))).thenReturn(registration); + + assertThatCode(() -> + new AbstractSecurityWebApplicationInitializer() { + protected void afterSpringSecurityFilterChain(ServletContext servletContext) { + appendFilters(context, filter, null); + } + }.onStartup(context)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("filters cannot contain null values"); + + verify(context, times(2)).addFilter(anyString(), any(Filter.class)); + } + + @Test + public void onStartupWhenDefaultsThenSessionTrackingModes() { + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + + ArgumentCaptor> modesCaptor = ArgumentCaptor + .forClass(new HashSet(){}.getClass()); + doNothing().when(context).setSessionTrackingModes(modesCaptor.capture()); + + new AbstractSecurityWebApplicationInitializer() { }.onStartup(context); + + assertProxyDefaults(proxyCaptor.getValue()); + + Set modes = modesCaptor.getValue(); + assertThat(modes).hasSize(1); + assertThat(modes).containsExactly(SessionTrackingMode.COOKIE); + } + + @Test + public void onStartupWhenSessionTrackingModesConfiguredThenUsed() { + ServletContext context = mock(ServletContext.class); + FilterRegistration.Dynamic registration = mock(FilterRegistration.Dynamic.class); + + ArgumentCaptor proxyCaptor = ArgumentCaptor.forClass(DelegatingFilterProxy.class); + when(context.addFilter(eq("springSecurityFilterChain"), proxyCaptor.capture())) + .thenReturn(registration); + + ArgumentCaptor> modesCaptor = ArgumentCaptor + .forClass(new HashSet(){}.getClass()); + doNothing().when(context).setSessionTrackingModes(modesCaptor.capture()); + + new AbstractSecurityWebApplicationInitializer() { + @Override + public Set getSessionTrackingModes() { + return Collections.singleton(SessionTrackingMode.SSL); + } + }.onStartup(context); + + assertProxyDefaults(proxyCaptor.getValue()); + + Set modes = modesCaptor.getValue(); + assertThat(modes).hasSize(1); + assertThat(modes).containsExactly(SessionTrackingMode.SSL); + } + + @Test + public void defaultFilterNameEqualsSpringSecurityFilterChain() { + assertThat(AbstractSecurityWebApplicationInitializer.DEFAULT_FILTER_NAME) + .isEqualTo("springSecurityFilterChain"); + } + + private static void verifyNoAddListener(ServletContext context) { + verify(context, times(0)).addListener(anyString()); + verify(context, times(0)).addListener(any(EventListener.class)); + verify(context, times(0)).addListener(any(Class.class)); + } + + private static void assertProxyDefaults(DelegatingFilterProxy proxy) { + assertThat(proxy.getContextAttribute()).isNull(); + assertThat(proxy).hasFieldOrPropertyWithValue("targetBeanName", "springSecurityFilterChain"); + } +}