diff --git a/spring-test/src/main/java/org/springframework/test/context/TestContext.java b/spring-test/src/main/java/org/springframework/test/context/TestContext.java
index cce8366ece8..3f632534178 100644
--- a/spring-test/src/main/java/org/springframework/test/context/TestContext.java
+++ b/spring-test/src/main/java/org/springframework/test/context/TestContext.java
@@ -42,6 +42,7 @@ import org.springframework.test.annotation.DirtiesContext.HierarchyMode;
* override {@link #setMethodInvoker(MethodInvoker)} and {@link #getMethodInvoker()}.
*
* @author Sam Brannen
+ * @author Andreas Ahlenstorf
* @since 2.5
* @see TestContextManager
* @see TestExecutionListener
@@ -110,6 +111,25 @@ public interface TestContext extends AttributeAccessor, Serializable {
*/
Object getTestInstance();
+ /**
+ * Tests whether a test method is part of this test context. Returns
+ * {@code true} if this context has a current test method, {@code false}
+ * otherwise.
+ *
+ *
The default implementation of this method always returns {@code false}.
+ * Custom {@code TestContext} implementations are therefore highly encouraged
+ * to override this method with a more meaningful implementation. Note that
+ * the standard {@code TestContext} implementation in Spring overrides this
+ * method appropriately.
+ * @return {@code true} if the test execution has already entered a test
+ * method
+ * @since 6.1
+ * @see #getTestMethod()
+ */
+ default boolean hasTestMethod() {
+ return false;
+ }
+
/**
* Get the current {@linkplain Method test method} for this test context.
*
Note: this is a mutable property.
diff --git a/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java b/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java
index 943b9edff86..6ab67c8cfa5 100644
--- a/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java
+++ b/spring-test/src/main/java/org/springframework/test/context/jdbc/Sql.java
@@ -33,6 +33,11 @@ import org.springframework.core.annotation.AliasFor;
*
*
Method-level declarations override class-level declarations by default,
* but this behavior can be configured via {@link SqlMergeMode @SqlMergeMode}.
+ * However, this does not apply to class-level declarations that use
+ * {@link ExecutionPhase#BEFORE_TEST_CLASS} or
+ * {@link ExecutionPhase#AFTER_TEST_CLASS}. Such declarations are retained and
+ * scripts and statements are executed once per class in addition to any
+ * method-level annotations.
*
*
Script execution is performed by the {@link SqlScriptsTestExecutionListener},
* which is enabled by default.
@@ -61,6 +66,7 @@ import org.springframework.core.annotation.AliasFor;
* modules as well as their transitive dependencies to be present on the classpath.
*
* @author Sam Brannen
+ * @author Andreas Ahlenstorf
* @since 4.1
* @see SqlConfig
* @see SqlMergeMode
@@ -161,6 +167,18 @@ public @interface Sql {
*/
enum ExecutionPhase {
+ /**
+ * The configured SQL scripts and statements will be executed
+ * once before any test method is run.
+ */
+ BEFORE_TEST_CLASS,
+
+ /**
+ * The configured SQL scripts and statements will be executed
+ * once after any test method is run.
+ */
+ AFTER_TEST_CLASS,
+
/**
* The configured SQL scripts and statements will be executed
* before the corresponding test method.
diff --git a/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java
index c7e51561f8d..90d98db5c4c 100644
--- a/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java
+++ b/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java
@@ -67,10 +67,17 @@ import static org.springframework.util.ResourceUtils.CLASSPATH_URL_PREFIX;
* {@link Sql#scripts scripts} and inlined {@link Sql#statements statements}
* configured via the {@link Sql @Sql} annotation.
*
- *
Scripts and inlined statements will be executed {@linkplain #beforeTestMethod(TestContext) before}
- * or {@linkplain #afterTestMethod(TestContext) after} execution of the corresponding
- * {@linkplain java.lang.reflect.Method test method}, depending on the configured
- * value of the {@link Sql#executionPhase executionPhase} flag.
+ *
Class-level annotations that are constrained to a class-level execution
+ * phase ({@link ExecutionPhase#BEFORE_TEST_CLASS} or
+ * {@link ExecutionPhase#AFTER_TEST_CLASS}) will be run
+ * {@linkplain #beforeTestClass(TestContext) once before all test methods} or
+ * {@linkplain #afterTestMethod(TestContext) once after all test methods},
+ * respectively. All other scripts and inlined statements will be executed
+ * {@linkplain #beforeTestMethod(TestContext) before} or
+ * {@linkplain #afterTestMethod(TestContext) after} execution of the
+ * corresponding {@linkplain java.lang.reflect.Method test method}, depending
+ * on the configured value of the {@link Sql#executionPhase executionPhase}
+ * flag.
*
*
Scripts and inlined statements will be executed without a transaction,
* within an existing Spring-managed transaction, or within an isolated transaction,
@@ -98,6 +105,7 @@ import static org.springframework.util.ResourceUtils.CLASSPATH_URL_PREFIX;
*
* @author Sam Brannen
* @author Dmitry Semukhin
+ * @author Andreas Ahlenstorf
* @since 4.1
* @see Sql
* @see SqlConfig
@@ -126,6 +134,26 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
return 5000;
}
+ /**
+ * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
+ * {@link TestContext} once per test class before any test method
+ * is run.
+ */
+ @Override
+ public void beforeTestClass(TestContext testContext) throws Exception {
+ executeBeforeOrAfterClassSqlScripts(testContext, ExecutionPhase.BEFORE_TEST_CLASS);
+ }
+
+ /**
+ * Execute SQL scripts configured via {@link Sql @Sql} for the supplied
+ * {@link TestContext} once per test class after all test methods
+ * have been run.
+ */
+ @Override
+ public void afterTestClass(TestContext testContext) throws Exception {
+ executeBeforeOrAfterClassSqlScripts(testContext, ExecutionPhase.AFTER_TEST_CLASS);
+ }
+
/**
* Execute SQL scripts configured via {@link Sql @Sql} for the supplied
* {@link TestContext} before the current test method.
@@ -159,6 +187,17 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
registerClasspathResources(getScripts(sql, testClass, testMethod, false), runtimeHints, classLoader)));
}
+ /**
+ * Execute class-level SQL scripts configured via {@link Sql @Sql} for the
+ * supplied {@link TestContext} and the execution phases
+ * {@link ExecutionPhase#BEFORE_TEST_CLASS} and
+ * {@link ExecutionPhase#AFTER_TEST_CLASS}.
+ */
+ private void executeBeforeOrAfterClassSqlScripts(TestContext testContext, ExecutionPhase executionPhase) {
+ Class> testClass = testContext.getTestClass();
+ executeSqlScripts(getSqlAnnotationsFor(testClass), testContext, executionPhase, true);
+ }
+
/**
* Execute SQL scripts configured via {@link Sql @Sql} for the supplied
* {@link TestContext} and {@link ExecutionPhase}.
@@ -246,6 +285,9 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
private void executeSqlScripts(
Sql sql, ExecutionPhase executionPhase, TestContext testContext, boolean classLevel) {
+ Assert.isTrue(classLevel || isValidMethodLevelPhase(sql.executionPhase()),
+ () -> "%s cannot be used on methods".formatted(sql.executionPhase()));
+
if (executionPhase != sql.executionPhase()) {
return;
}
@@ -260,7 +302,12 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
.formatted(executionPhase, testContext.getTestClass().getName()));
}
- String[] scripts = getScripts(sql, testContext.getTestClass(), testContext.getTestMethod(), classLevel);
+ Method testMethod = null;
+ if (testContext.hasTestMethod()) {
+ testMethod = testContext.getTestMethod();
+ }
+
+ String[] scripts = getScripts(sql, testContext.getTestClass(), testMethod, classLevel);
List scriptResources = TestContextResourceUtils.convertToResourceList(
testContext.getApplicationContext(), scripts);
for (String stmt : sql.statements()) {
@@ -354,7 +401,7 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
return null;
}
- private String[] getScripts(Sql sql, Class> testClass, Method testMethod, boolean classLevel) {
+ private String[] getScripts(Sql sql, Class> testClass, @Nullable Method testMethod, boolean classLevel) {
String[] scripts = sql.scripts();
if (ObjectUtils.isEmpty(scripts) && ObjectUtils.isEmpty(sql.statements())) {
scripts = new String[] {detectDefaultScript(testClass, testMethod, classLevel)};
@@ -366,7 +413,9 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
* Detect a default SQL script by implementing the algorithm defined in
* {@link Sql#scripts}.
*/
- private String detectDefaultScript(Class> testClass, Method testMethod, boolean classLevel) {
+ private String detectDefaultScript(Class> testClass, @Nullable Method testMethod, boolean classLevel) {
+ Assert.state(classLevel || testMethod != null, "Method-level @Sql requires a testMethod");
+
String elementType = (classLevel ? "class" : "method");
String elementName = (classLevel ? testClass.getName() : testMethod.toString());
@@ -407,4 +456,9 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
.forEach(runtimeHints.resources()::registerResource);
}
+ private static boolean isValidMethodLevelPhase(ExecutionPhase executionPhase) {
+ // Class-level phases cannot be used on methods.
+ return executionPhase == ExecutionPhase.BEFORE_TEST_METHOD ||
+ executionPhase == ExecutionPhase.AFTER_TEST_METHOD;
+ }
}
diff --git a/spring-test/src/main/java/org/springframework/test/context/support/DefaultTestContext.java b/spring-test/src/main/java/org/springframework/test/context/support/DefaultTestContext.java
index 03b6a25827e..d95b2457e62 100644
--- a/spring-test/src/main/java/org/springframework/test/context/support/DefaultTestContext.java
+++ b/spring-test/src/main/java/org/springframework/test/context/support/DefaultTestContext.java
@@ -41,6 +41,7 @@ import org.springframework.util.StringUtils;
* @author Sam Brannen
* @author Juergen Hoeller
* @author Rob Harrop
+ * @author Andreas Ahlenstorf
* @since 4.0
*/
@SuppressWarnings("serial")
@@ -166,6 +167,11 @@ public class DefaultTestContext implements TestContext {
return testInstance;
}
+ @Override
+ public boolean hasTestMethod() {
+ return this.testMethod != null;
+ }
+
@Override
public final Method getTestMethod() {
Method testMethod = this.testMethod;
diff --git a/spring-test/src/main/java/org/springframework/test/context/transaction/TestContextTransactionUtils.java b/spring-test/src/main/java/org/springframework/test/context/transaction/TestContextTransactionUtils.java
index 38301afc42e..addfd8ded26 100644
--- a/spring-test/src/main/java/org/springframework/test/context/transaction/TestContextTransactionUtils.java
+++ b/spring-test/src/main/java/org/springframework/test/context/transaction/TestContextTransactionUtils.java
@@ -1,5 +1,5 @@
/*
- * Copyright 2002-2022 the original author or authors.
+ * Copyright 2002-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -46,6 +46,7 @@ import org.springframework.util.StringUtils;
*
* @author Sam Brannen
* @author Juergen Hoeller
+ * @author Andreas Ahlenstorf
* @since 4.1
*/
public abstract class TestContextTransactionUtils {
@@ -227,7 +228,8 @@ public abstract class TestContextTransactionUtils {
/**
* Create a delegating {@link TransactionAttribute} for the supplied target
* {@link TransactionAttribute} and {@link TestContext}, using the names of
- * the test class and test method to build the name of the transaction.
+ * the test class and test method (if available) to build the name of the
+ * transaction.
* @param testContext the {@code TestContext} upon which to base the name
* @param targetAttribute the {@code TransactionAttribute} to delegate to
* @return the delegating {@code TransactionAttribute}
@@ -248,7 +250,13 @@ public abstract class TestContextTransactionUtils {
public TestContextTransactionAttribute(TransactionAttribute targetAttribute, TestContext testContext) {
super(targetAttribute);
- this.name = ClassUtils.getQualifiedMethodName(testContext.getTestMethod(), testContext.getTestClass());
+
+ if (testContext.hasTestMethod()) {
+ this.name = ClassUtils.getQualifiedMethodName(testContext.getTestMethod(), testContext.getTestClass());
+ }
+ else {
+ this.name = testContext.getTestClass().getName();
+ }
}
@Override
diff --git a/spring-test/src/test/java/org/springframework/test/context/jdbc/AfterTestClassSqlScriptsTests.java b/spring-test/src/test/java/org/springframework/test/context/jdbc/AfterTestClassSqlScriptsTests.java
new file mode 100644
index 00000000000..1d6633288ea
--- /dev/null
+++ b/spring-test/src/test/java/org/springframework/test/context/jdbc/AfterTestClassSqlScriptsTests.java
@@ -0,0 +1,87 @@
+/*
+ * Copyright 2002-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.test.context.jdbc;
+
+import javax.sql.DataSource;
+
+import org.junit.jupiter.api.Order;
+import org.junit.jupiter.api.Test;
+
+import org.springframework.core.Ordered;
+import org.springframework.jdbc.BadSqlGrammarException;
+import org.springframework.jdbc.core.JdbcTemplate;
+import org.springframework.test.annotation.Commit;
+import org.springframework.test.annotation.DirtiesContext;
+import org.springframework.test.context.TestContext;
+import org.springframework.test.context.TestExecutionListener;
+import org.springframework.test.context.TestExecutionListeners;
+import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
+import org.springframework.test.context.transaction.TestContextTransactionUtils;
+
+import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+
+/**
+ * Verifies that {@link Sql @Sql} with {@link Sql.ExecutionPhase#AFTER_TEST_CLASS} is run after all tests in the class
+ * have been run.
+ *
+ * @author Andreas Ahlenstorf
+ * @since 6.1
+ */
+@SpringJUnitConfig(PopulatedSchemaDatabaseConfig.class)
+@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_CLASS)
+@Sql(value = {"drop-schema.sql"}, executionPhase = Sql.ExecutionPhase.AFTER_TEST_CLASS)
+@TestExecutionListeners(
+ value = AfterTestClassSqlScriptsTests.VerifyTestExecutionListener.class,
+ mergeMode = TestExecutionListeners.MergeMode.MERGE_WITH_DEFAULTS
+)
+class AfterTestClassSqlScriptsTests extends AbstractTransactionalTests {
+
+ @Test
+ @Order(1)
+ @Sql(scripts = "data-add-catbert.sql")
+ @Commit
+ void databaseHasBeenInitialized() {
+ assertUsers("Catbert");
+ }
+
+ @Test
+ @Order(2)
+ @Sql(scripts = "data-add-dogbert.sql")
+ @Commit
+ void databaseIsNotWipedBetweenTests() {
+ assertUsers("Catbert", "Dogbert");
+ }
+
+ static class VerifyTestExecutionListener implements TestExecutionListener, Ordered {
+
+ @Override
+ public void afterTestClass(TestContext testContext) throws Exception {
+ DataSource dataSource = TestContextTransactionUtils.retrieveDataSource(testContext, null);
+ JdbcTemplate jdbcTemplate = new JdbcTemplate(dataSource);
+
+ assertThatExceptionOfType(BadSqlGrammarException.class)
+ .isThrownBy(() -> jdbcTemplate.queryForList("SELECT name FROM user", String.class));
+ }
+
+ @Override
+ public int getOrder() {
+ // Must run before DirtiesContextTestExecutionListener. Otherwise, the old data source will be removed and
+ // replaced with a new one.
+ return 3001;
+ }
+ }
+}
diff --git a/spring-test/src/test/java/org/springframework/test/context/jdbc/BeforeTestClassSqlScriptsTests.java b/spring-test/src/test/java/org/springframework/test/context/jdbc/BeforeTestClassSqlScriptsTests.java
new file mode 100644
index 00000000000..acaae772cd7
--- /dev/null
+++ b/spring-test/src/test/java/org/springframework/test/context/jdbc/BeforeTestClassSqlScriptsTests.java
@@ -0,0 +1,58 @@
+/*
+ * Copyright 2002-2023 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.test.context.jdbc;
+
+import org.junit.jupiter.api.Test;
+
+import org.springframework.test.annotation.DirtiesContext;
+import org.springframework.test.context.junit.jupiter.SpringJUnitConfig;
+
+import static org.springframework.test.context.jdbc.SqlMergeMode.MergeMode.MERGE;
+import static org.springframework.test.context.jdbc.SqlMergeMode.MergeMode.OVERRIDE;
+
+/**
+ * Verifies that {@link Sql @Sql} with {@link Sql.ExecutionPhase#BEFORE_TEST_CLASS} is run before all tests in the class
+ * have been run.
+ *
+ * @author Andreas Ahlenstorf
+ * @since 6.1
+ */
+@SpringJUnitConfig(classes = EmptyDatabaseConfig.class)
+@DirtiesContext
+@Sql(value = {"schema.sql", "data-add-catbert.sql"}, executionPhase = Sql.ExecutionPhase.BEFORE_TEST_CLASS)
+class BeforeTestClassSqlScriptsTests extends AbstractTransactionalTests {
+
+ @Test
+ void classLevelScriptsHaveBeenRun() {
+ assertUsers("Catbert");
+ }
+
+ @Test
+ @Sql("data-add-dogbert.sql")
+ @SqlMergeMode(MERGE)
+ void mergeDoesNotAffectClassLevelPhase() {
+ assertUsers("Catbert", "Dogbert");
+ }
+
+ @Test
+ @Sql({"data-add-dogbert.sql"})
+ @SqlMergeMode(OVERRIDE)
+ void overrideDoesNotAffectClassLevelPhase() {
+ assertUsers("Dogbert", "Catbert");
+ }
+
+}
diff --git a/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListenerTests.java b/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListenerTests.java
index 8d8a54b2c94..ff9f3b953f2 100644
--- a/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListenerTests.java
+++ b/spring-test/src/test/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListenerTests.java
@@ -25,6 +25,7 @@ import org.springframework.test.context.TestContext;
import org.springframework.test.context.jdbc.SqlConfig.TransactionMode;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
+import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.BDDMockito.given;
@@ -34,6 +35,7 @@ import static org.mockito.Mockito.mock;
* Unit tests for {@link SqlScriptsTestExecutionListener}.
*
* @author Sam Brannen
+ * @author Andreas Ahlenstorf
* @since 4.1
*/
class SqlScriptsTestExecutionListenerTests {
@@ -56,6 +58,7 @@ class SqlScriptsTestExecutionListenerTests {
void missingValueAndScriptsAndStatementsAtMethodLevel() throws Exception {
Class> clazz = MissingValueAndScriptsAndStatementsAtMethodLevel.class;
BDDMockito.> given(testContext.getTestClass()).willReturn(clazz);
+ given(testContext.hasTestMethod()).willReturn(true);
given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("foo"));
assertExceptionContains(clazz.getSimpleName() + ".foo" + ".sql");
@@ -102,6 +105,30 @@ class SqlScriptsTestExecutionListenerTests {
assertExceptionContains("supply at least a DataSource or PlatformTransactionManager");
}
+ @Test
+ void beforeTestClassOnMethod() throws Exception {
+ Class> clazz = ClassLevelExecutionPhaseOnMethod.class;
+ BDDMockito.> given(testContext.getTestClass()).willReturn(clazz);
+ given(testContext.hasTestMethod()).willReturn(true);
+ given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("beforeTestClass"));
+
+ assertThatIllegalArgumentException()
+ .isThrownBy(() -> listener.beforeTestMethod(testContext))
+ .withMessage("BEFORE_TEST_CLASS cannot be used on methods");
+ }
+
+ @Test
+ void afterTestClassOnMethod() throws Exception {
+ Class> clazz = ClassLevelExecutionPhaseOnMethod.class;
+ BDDMockito.> given(testContext.getTestClass()).willReturn(clazz);
+ given(testContext.hasTestMethod()).willReturn(true);
+ given(testContext.getTestMethod()).willReturn(clazz.getDeclaredMethod("afterTestClass"));
+
+ assertThatIllegalArgumentException()
+ .isThrownBy(() -> listener.beforeTestMethod(testContext))
+ .withMessage("AFTER_TEST_CLASS cannot be used on methods");
+ }
+
private void assertExceptionContains(String msg) throws Exception {
assertThatIllegalStateException().isThrownBy(() ->
listener.beforeTestMethod(testContext))
@@ -146,4 +173,14 @@ class SqlScriptsTestExecutionListenerTests {
}
}
+ static class ClassLevelExecutionPhaseOnMethod {
+
+ @Sql(scripts = "foo.sql", executionPhase = Sql.ExecutionPhase.BEFORE_TEST_CLASS)
+ public void beforeTestClass() {
+ }
+
+ @Sql(scripts = "foo.sql", executionPhase = Sql.ExecutionPhase.AFTER_TEST_CLASS)
+ public void afterTestClass() {
+ }
+ }
}