diff --git a/spring-boot/src/main/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContext.java b/spring-boot/src/main/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContext.java index 22528ce6b33..268b2dee11b 100644 --- a/spring-boot/src/main/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContext.java +++ b/spring-boot/src/main/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContext.java @@ -225,27 +225,35 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext */ protected Collection getServletContextInitializerBeans() { - Set initializers = new LinkedHashSet(); + List filters = new ArrayList(); + List servlets = new ArrayList(); + List listeners = new ArrayList(); + List other = new ArrayList(); Set servletRegistrations = new LinkedHashSet(); Set filterRegistrations = new LinkedHashSet(); Set listenerRegistrations = new LinkedHashSet(); for (Entry initializerBean : getOrderedBeansOfType(ServletContextInitializer.class)) { ServletContextInitializer initializer = initializerBean.getValue(); - initializers.add(initializer); if (initializer instanceof ServletRegistrationBean) { + servlets.add(initializer); ServletRegistrationBean servlet = (ServletRegistrationBean) initializer; servletRegistrations.add(servlet.getServlet()); } - if (initializer instanceof FilterRegistrationBean) { + else if (initializer instanceof FilterRegistrationBean) { + filters.add(initializer); FilterRegistrationBean filter = (FilterRegistrationBean) initializer; filterRegistrations.add(filter.getFilter()); } - if (initializer instanceof ServletListenerRegistrationBean) { + else if (initializer instanceof ServletListenerRegistrationBean) { + listeners.add(initializer); listenerRegistrations .add(((ServletListenerRegistrationBean) initializer) .getListener()); } + else { + other.add(initializer); + } } List> servletBeans = getOrderedBeansOfType(Servlet.class); @@ -261,7 +269,9 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext servlet, url); registration.setName(name); registration.setMultipartConfig(getMultipartConfig()); - initializers.add(registration); + registration.setOrder(CustomOrderAwareComparator.INSTANCE + .getOrder(servlet)); + servlets.add(registration); } } @@ -271,7 +281,9 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext if (!filterRegistrations.contains(filter)) { FilterRegistrationBean registration = new FilterRegistrationBean(filter); registration.setName(name); - initializers.add(registration); + registration.setOrder(CustomOrderAwareComparator.INSTANCE + .getOrder(filter)); + filters.add(registration); } } @@ -285,12 +297,23 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext ServletListenerRegistrationBean registration = new ServletListenerRegistrationBean( listener); registration.setName(name); - initializers.add(registration); + registration.setOrder(CustomOrderAwareComparator.INSTANCE + .getOrder(listener)); + listeners.add(registration); } } } - - return initializers; + AnnotationAwareOrderComparator.sort(filters); + AnnotationAwareOrderComparator.sort(servlets); + AnnotationAwareOrderComparator.sort(listeners); + AnnotationAwareOrderComparator.sort(other); + + List list = new ArrayList( + filters); + list.addAll(servlets); + list.addAll(listeners); + list.addAll(other); + return list; } private MultipartConfigElement getMultipartConfig() { @@ -425,4 +448,15 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext return this.embeddedServletContainer; } + private static class CustomOrderAwareComparator extends + AnnotationAwareOrderComparator { + + public static CustomOrderAwareComparator INSTANCE = new CustomOrderAwareComparator(); + + @Override + protected int getOrder(Object obj) { + return super.getOrder(obj); + } + } + } diff --git a/spring-boot/src/test/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContextTests.java b/spring-boot/src/test/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContextTests.java index 007e48e840a..29d54dbbb47 100644 --- a/spring-boot/src/test/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContextTests.java +++ b/spring-boot/src/test/java/org/springframework/boot/context/embedded/EmbeddedWebApplicationContextTests.java @@ -16,13 +16,18 @@ package org.springframework.boot.context.embedded; +import java.io.IOException; import java.lang.reflect.Field; import java.util.Properties; import javax.servlet.Filter; +import javax.servlet.FilterChain; import javax.servlet.Servlet; import javax.servlet.ServletContext; import javax.servlet.ServletContextListener; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; import org.junit.After; import org.junit.Before; @@ -30,6 +35,7 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; import org.mockito.InOrder; +import org.mockito.Mockito; import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.ConstructorArgumentValues; @@ -39,9 +45,11 @@ import org.springframework.context.ApplicationListener; import org.springframework.context.support.AbstractApplicationContext; import org.springframework.context.support.PropertySourcesPlaceholderConfigurer; import org.springframework.core.Ordered; +import org.springframework.core.annotation.Order; import org.springframework.web.context.ServletContextAware; import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.request.SessionScope; +import org.springframework.web.filter.GenericFilterBean; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -192,6 +200,23 @@ public class EmbeddedWebApplicationContextTests { verify(escf.getRegisteredServlet(0).getRegistration()).addMapping("/"); } + @Test + public void orderedBeanInsertedCorrectly() throws Exception { + addEmbeddedServletContainerFactoryBean(); + OrderedFilter filter = new OrderedFilter(); + this.context.registerBeanDefinition("filterBean", beanDefinition(filter)); + FilterRegistrationBean registration = new FilterRegistrationBean(); + registration.setFilter(Mockito.mock(Filter.class)); + registration.setOrder(100); + this.context.registerBeanDefinition("filterRegistrationBean", + beanDefinition(registration)); + this.context.refresh(); + MockEmbeddedServletContainerFactory escf = getEmbeddedServletContainerFactory(); + verify(escf.getServletContext()).addFilter("filterBean", filter); + verify(escf.getServletContext()).addFilter("object", registration.getFilter()); + assertEquals(filter, escf.getRegisteredFilter(0).getFilter()); + } + @Test public void multipleServletBeans() throws Exception { addEmbeddedServletContainerFactoryBean(); @@ -422,4 +447,14 @@ public class EmbeddedWebApplicationContextTests { } } + + @Order(10) + protected static class OrderedFilter extends GenericFilterBean { + + @Override + public void doFilter(ServletRequest request, ServletResponse response, + FilterChain chain) throws IOException, ServletException { + } + + } }