From da8d50f91fed1470c69baa0cbaeeffc4b31a0e8b Mon Sep 17 00:00:00 2001 From: Sam Brannen Date: Thu, 3 Aug 2017 23:19:57 +0300 Subject: [PATCH] Revise SpringExtension based on recent changes in JUnit Jupiter This commit revises the implementation of the SpringExtension to use the getRequired*() methods in the ExtensionContext which are now built into JUnit Jupiter thanks to inspiration from the initial "convenience" methods implemented here. --- .../junit/jupiter/SpringExtension.java | 60 +++++-------------- 1 file changed, 14 insertions(+), 46 deletions(-) diff --git a/spring-test/src/main/java/org/springframework/test/context/junit/jupiter/SpringExtension.java b/spring-test/src/main/java/org/springframework/test/context/junit/jupiter/SpringExtension.java index 1af93a39c5e..e0f365cab10 100644 --- a/spring-test/src/main/java/org/springframework/test/context/junit/jupiter/SpringExtension.java +++ b/spring-test/src/main/java/org/springframework/test/context/junit/jupiter/SpringExtension.java @@ -84,7 +84,7 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes getTestContextManager(context).afterTestClass(); } finally { - context.getStore(NAMESPACE).remove(getRequiredTestClass(context)); + getStore(context).remove(context.getRequiredTestClass()); } } @@ -101,8 +101,8 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes */ @Override public void beforeEach(ExtensionContext context) throws Exception { - Object testInstance = getRequiredTestInstance(context); - Method testMethod = getRequiredTestMethod(context); + Object testInstance = context.getRequiredTestInstance(); + Method testMethod = context.getRequiredTestMethod(); getTestContextManager(context).beforeTestMethod(testInstance, testMethod); } @@ -111,8 +111,8 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes */ @Override public void beforeTestExecution(ExtensionContext context) throws Exception { - Object testInstance = getRequiredTestInstance(context); - Method testMethod = getRequiredTestMethod(context); + Object testInstance = context.getRequiredTestInstance(); + Method testMethod = context.getRequiredTestMethod(); getTestContextManager(context).beforeTestExecution(testInstance, testMethod); } @@ -121,8 +121,8 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes */ @Override public void afterTestExecution(ExtensionContext context) throws Exception { - Object testInstance = getRequiredTestInstance(context); - Method testMethod = getRequiredTestMethod(context); + Object testInstance = context.getRequiredTestInstance(); + Method testMethod = context.getRequiredTestMethod(); Throwable testException = context.getExecutionException().orElse(null); getTestContextManager(context).afterTestExecution(testInstance, testMethod, testException); } @@ -132,8 +132,8 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes */ @Override public void afterEach(ExtensionContext context) throws Exception { - Object testInstance = getRequiredTestInstance(context); - Method testMethod = getRequiredTestMethod(context); + Object testInstance = context.getRequiredTestInstance(); + Method testMethod = context.getRequiredTestMethod(); Throwable testException = context.getExecutionException().orElse(null); getTestContextManager(context).afterTestMethod(testInstance, testMethod, testException); } @@ -171,7 +171,7 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes @Nullable public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext) { Parameter parameter = parameterContext.getParameter(); - Class testClass = getRequiredTestClass(extensionContext); + Class testClass = extensionContext.getRequiredTestClass(); ApplicationContext applicationContext = getApplicationContext(extensionContext); return ParameterAutowireUtils.resolveDependency(parameter, testClass, applicationContext); } @@ -194,45 +194,13 @@ public class SpringExtension implements BeforeAllCallback, AfterAllCallback, Tes */ private static TestContextManager getTestContextManager(ExtensionContext context) { Assert.notNull(context, "ExtensionContext must not be null"); - Class testClass = getRequiredTestClass(context); - Store store = context.getStore(NAMESPACE); + Class testClass = context.getRequiredTestClass(); + Store store = getStore(context); return store.getOrComputeIfAbsent(testClass, TestContextManager::new, TestContextManager.class); } - /** - * Get the test class associated with the supplied {@code ExtensionContext}. - * @return the test class - * @throws IllegalStateException if the extension context does not contain - * a test class - */ - private static Class getRequiredTestClass(ExtensionContext context) throws IllegalStateException { - Assert.notNull(context, "ExtensionContext must not be null"); - return context.getTestClass().orElseThrow( - () -> new IllegalStateException("JUnit failed to supply the test class in the ExtensionContext")); - } - - /** - * Get the test instance associated with the supplied {@code ExtensionContext}. - * @return the test instance - * @throws IllegalStateException if the extension context does not contain - * a test instance - */ - private static Object getRequiredTestInstance(ExtensionContext context) throws IllegalStateException { - Assert.notNull(context, "ExtensionContext must not be null"); - return context.getTestInstance().orElseThrow( - () -> new IllegalStateException("JUnit failed to supply the test instance in the ExtensionContext")); - } - - /** - * Get the test method associated with the supplied {@code ExtensionContext}. - * @return the test method - * @throws IllegalStateException if the extension context does not contain - * a test method - */ - private static Method getRequiredTestMethod(ExtensionContext context) throws IllegalStateException { - Assert.notNull(context, "ExtensionContext must not be null"); - return context.getTestMethod().orElseThrow( - () -> new IllegalStateException("JUnit failed to supply the test method in the ExtensionContext")); + private static Store getStore(ExtensionContext context) { + return context.getRoot().getStore(NAMESPACE); } }