diff --git a/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java b/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java index 3cc8a505d3..5e6328dc49 100644 --- a/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java +++ b/test/src/main/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListener.java @@ -71,7 +71,7 @@ public class WithSecurityContextTestExecutionListener extends if (withSecurityContext != null) { WithSecurityContextFactory factory = createFactory(withSecurityContext, context); Class type = (Class) GenericTypeResolver.resolveTypeArgument(factory.getClass(), WithSecurityContextFactory.class); - Annotation annotation = AnnotationUtils.findAnnotation(annotated, type); + Annotation annotation = findAnnotation(annotated, type); try { return factory.createSecurityContext(annotation); } @@ -83,6 +83,23 @@ public class WithSecurityContextTestExecutionListener extends return null; } + private Annotation findAnnotation(AnnotatedElement annotated, + Class type) { + Annotation findAnnotation = AnnotationUtils.findAnnotation(annotated, type); + if (findAnnotation != null) { + return findAnnotation; + } + Annotation[] allAnnotations = AnnotationUtils.getAnnotations(annotated); + for (Annotation annotationToTest : allAnnotations) { + WithSecurityContext withSecurityContext = AnnotationUtils.findAnnotation( + annotationToTest.annotationType(), WithSecurityContext.class); + if (withSecurityContext != null) { + return annotationToTest; + } + } + return null; + } + private WithSecurityContextFactory createFactory( WithSecurityContext withSecurityContext, TestContext testContext) { Class> clazz = withSecurityContext diff --git a/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListenerTests.java b/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListenerTests.java index 029b860401..a468115321 100644 --- a/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListenerTests.java +++ b/test/src/test/java/org/springframework/security/test/context/support/WithSecurityContextTestExecutionListenerTests.java @@ -15,16 +15,29 @@ */ package org.springframework.security.test.context.support; -import static org.assertj.core.api.Assertions.*; - +import java.lang.annotation.Annotation; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.List; +import org.junit.After; import org.junit.Before; import org.junit.Test; + import org.springframework.core.annotation.AnnotationAwareOrderComparator; +import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.test.context.TestContext; import org.springframework.test.context.TestExecutionListener; import org.springframework.test.context.support.DependencyInjectionTestExecutionListener; +import org.springframework.util.ReflectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; /** * @author Rob Winch @@ -35,7 +48,12 @@ public class WithSecurityContextTestExecutionListenerTests { @Before public void setup() { - listener = new WithSecurityContextTestExecutionListener(); + this.listener = new WithSecurityContextTestExecutionListener(); + } + + @After + public void cleanup() { + SecurityContextHolder.clearContext(); } // SEC-2709 @@ -45,10 +63,48 @@ public class WithSecurityContextTestExecutionListenerTests { List listeners = new ArrayList(); listeners.add(otherListener); - listeners.add(listener); + listeners.add(this.listener); AnnotationAwareOrderComparator.sort(listeners); - assertThat(listeners).containsSequence(listener, otherListener); + assertThat(listeners).containsSequence(this.listener, otherListener); + } + + @Test + // gh-3837 + public void handlesGenericAnnotation() throws Exception { + Method method = ReflectionUtils.findMethod( + WithSecurityContextTestExecutionListenerTests.class, + "handlesGenericAnnotationTestMethod"); + TestContext testContext = mock(TestContext.class); + when(testContext.getTestMethod()).thenReturn(method); + when(testContext.getApplicationContext()) + .thenThrow(new IllegalStateException("")); + + this.listener.beforeTestMethod(testContext); + + assertThat(SecurityContextHolder.getContext().getAuthentication().getPrincipal()) + .isInstanceOf(WithSuperClassWithSecurityContext.class); + } + + @WithSuperClassWithSecurityContext + public void handlesGenericAnnotationTestMethod() { + } + + @Retention(RetentionPolicy.RUNTIME) + @WithSecurityContext(factory = SuperClassWithSecurityContextFactory.class) + @interface WithSuperClassWithSecurityContext { + String username() default "WithSuperClassWithSecurityContext"; + } + + static class SuperClassWithSecurityContextFactory + implements WithSecurityContextFactory { + + @Override + public SecurityContext createSecurityContext(Annotation annotation) { + SecurityContext context = SecurityContextHolder.createEmptyContext(); + context.setAuthentication(new TestingAuthenticationToken(annotation, "NA")); + return context; + } } }