@ -128,11 +128,11 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
@@ -128,11 +128,11 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
private void executeSqlScripts ( TestContext testContext , ExecutionPhase executionPhase ) throws Exception {
boolean classLevel = false ;
Set < Sql > sqlAnnotations = AnnotationUtils . getRepeatableAnnotations ( testContext . getTestMethod ( ) , Sql . class ,
SqlGroup . class ) ;
Set < Sql > sqlAnnotations = AnnotationUtils . getRepeatableAnnotations (
testContext . getTestMethod ( ) , Sql . class , SqlGroup . class ) ;
if ( sqlAnnotations . isEmpty ( ) ) {
sqlAnnotations = AnnotationUtils . getRepeatableAnnotations ( testContext . getTestClass ( ) , Sql . class ,
SqlGroup . class ) ;
sqlAnnotations = AnnotationUtils . getRepeatableAnnotations (
testContext . getTestClass ( ) , Sql . class , SqlGroup . class ) ;
if ( ! sqlAnnotations . isEmpty ( ) ) {
classLevel = true ;
}
@ -155,14 +155,15 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
@@ -155,14 +155,15 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
* /
private void executeSqlScripts ( Sql sql , ExecutionPhase executionPhase , TestContext testContext , boolean classLevel )
throws Exception {
if ( executionPhase ! = sql . executionPhase ( ) ) {
return ;
}
MergedSqlConfig mergedSqlConfig = new MergedSqlConfig ( sql . config ( ) , testContext . getTestClass ( ) ) ;
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( String . format ( "Processing %s for execution phase [%s] and test context %s." , mergedSqlConfig ,
executionPhase , testContext ) ) ;
logger . debug ( String . format ( "Processing %s for execution phase [%s] and test context %s." ,
mergedSqlConfig , executionPhase , testContext ) ) ;
}
final ResourceDatabasePopulator populator = new ResourceDatabasePopulator ( ) ;
@ -177,15 +178,13 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
@@ -177,15 +178,13 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
String [ ] scripts = getScripts ( sql , testContext , classLevel ) ;
scripts = TestContextResourceUtils . convertToClasspathResourcePaths ( testContext . getTestClass ( ) , scripts ) ;
List < Resource > scriptResources = TestContextResourceUtils . convertToResourceList (
testContext . getApplicationContext ( ) , scripts ) ;
for ( String statement : sql . statements ( ) ) {
if ( StringUtils . hasText ( statement ) ) {
statement = statement . trim ( ) ;
scriptResources . add ( new ByteArrayResource ( statement . getBytes ( ) , "from inlined SQL statement: " + statement ) ) ;
testContext . getApplicationContext ( ) , scripts ) ;
for ( String stmt : sql . statements ( ) ) {
if ( StringUtils . hasText ( stmt ) ) {
stmt = stmt . trim ( ) ;
scriptResources . add ( new ByteArrayResource ( stmt . getBytes ( ) , "from inlined SQL statement: " + stmt ) ) ;
}
}
populator . setScripts ( scriptResources . toArray ( new Resource [ scriptResources . size ( ) ] ) ) ;
if ( logger . isDebugEnabled ( ) ) {
logger . debug ( "Executing SQL scripts: " + ObjectUtils . nullSafeToString ( scriptResources ) ) ;
@ -194,54 +193,45 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
@@ -194,54 +193,45 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
String dsName = mergedSqlConfig . getDataSource ( ) ;
String tmName = mergedSqlConfig . getTransactionManager ( ) ;
DataSource dataSource = TestContextTransactionUtils . retrieveDataSource ( testContext , dsName ) ;
final PlatformTransactionManager transactionManager = TestContextTransactionUtils . retrieveTransactionManager (
testContext , tmName ) ;
final boolean newTxRequired = mergedSqlConfig . getTransactionMode ( ) = = TransactionMode . ISOLATED ;
PlatformTransactionManager txMgr = TestContextTransactionUtils . retrieveTransactionManager ( testContext , tmName ) ;
boolean newTxRequired = ( mergedSqlConfig . getTransactionMode ( ) = = TransactionMode . ISOLATED ) ;
if ( transactionManage r = = null ) {
if ( txMg r = = null ) {
if ( newTxRequired ) {
throw new IllegalStateException ( String . format ( "Failed to execute SQL scripts for test context %s: "
+ "cannot execute SQL scripts using Transaction Mode "
+ "[%s] without a PlatformTransactionManager." , testContext , TransactionMode . ISOLATED ) ) ;
throw new IllegalStateException ( String . format ( "Failed to execute SQL scripts for test context %s: " +
"cannot execute SQL scripts using Transaction Mode [%s] without a PlatformTransactionManager. " ,
testContext , TransactionMode . ISOLATED ) ) ;
}
if ( dataSource = = null ) {
throw new IllegalStateException ( String . format ( "Failed to execute SQL scripts for test context %s: "
+ "supply at least a DataSource or PlatformTransactionManager." , testContext ) ) ;
throw new IllegalStateException ( String . format ( "Failed to execute SQL scripts for test context %s: " +
"supply at least a DataSource or PlatformTransactionManager." , testContext ) ) ;
}
// Execute scripts directly against the DataSource
populator . execute ( dataSource ) ;
}
else {
DataSource dataSourceFromTxMgr = getDataSourceFromTransactionManager ( transactionManager ) ;
DataSource dataSourceFromTxMgr = getDataSourceFromTransactionManager ( txMgr ) ;
// Ensure user configured an appropriate DataSource/TransactionManager pair.
if ( dataSource ! = null & & dataSourceFromTxMgr ! = null & & ! dataSource . equals ( dataSourceFromTxMgr ) ) {
throw new IllegalStateException ( String . format ( "Failed to execute SQL scripts for test context %s: " +
"the configured DataSource [%s] (named '%s') is not the one associated with " +
"transaction manager [%s] (named '%s')." , testContext , dataSource . getClass ( ) . getName ( ) ,
dsName , transactionManage r . getClass ( ) . getName ( ) , tmName ) ) ;
dsName , txMg r . getClass ( ) . getName ( ) , tmName ) ) ;
}
if ( dataSource = = null ) {
dataSource = dataSourceFromTxMgr ;
if ( dataSource = = null ) {
throw new IllegalStateException ( String . format ( "Failed to execute SQL scripts for test context %s: "
+ "could not obtain DataSource from transaction manager [%s] (named '%s')." , testContext ,
transactionManage r. getClass ( ) . getName ( ) , tmName ) ) ;
throw new IllegalStateException ( String . format ( "Failed to execute SQL scripts for " +
"test context %s: could not obtain DataSource from transaction manager [%s] (named '%s')." ,
testContext , txMg r . getClass ( ) . getName ( ) , tmName ) ) ;
}
}
final DataSource finalDataSource = dataSource ;
int propagation = ( newTxRequired ? TransactionDefinition . PROPAGATION_REQUIRES_NEW :
TransactionDefinition . PROPAGATION_REQUIRED ) ;
TransactionAttribute transactionAttribute = TestContextTransactionUtils . createDelegatingTransactionAttribute (
testContext , new DefaultTransactionAttribute ( propagation ) ) ;
new TransactionTemplate ( transactionManager , transactionAttribute ) . execute ( new TransactionCallbackWithoutResult ( ) {
TransactionAttribute txAttr = TestContextTransactionUtils . createDelegatingTransactionAttribute (
testContext , new DefaultTransactionAttribute ( propagation ) ) ;
new TransactionTemplate ( txMgr , txAttr ) . execute ( new TransactionCallbackWithoutResult ( ) {
@Override
public void doInTransactionWithoutResult ( TransactionStatus status ) {
populator . execute ( finalDataSource ) ;
@ -267,7 +257,7 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
@@ -267,7 +257,7 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
private String [ ] getScripts ( Sql sql , TestContext testContext , boolean classLevel ) {
String [ ] scripts = sql . scripts ( ) ;
if ( ObjectUtils . isEmpty ( scripts ) & & ObjectUtils . isEmpty ( sql . statements ( ) ) ) {
scripts = new String [ ] { detectDefaultScript ( testContext , classLevel ) } ;
scripts = new String [ ] { detectDefaultScript ( testContext , classLevel ) } ;
}
return scripts ;
}
@ -293,8 +283,8 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
@@ -293,8 +283,8 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
if ( classPathResource . exists ( ) ) {
if ( logger . isInfoEnabled ( ) ) {
logger . info ( String . format ( "Detected default SQL script \"%s\" for test %s [%s]" , prefixedResourcePath ,
elementType , elementName ) ) ;
logger . info ( String . format ( "Detected default SQL script \"%s\" for test %s [%s]" ,
prefixedResourcePath , elementType , elementName ) ) ;
}
return prefixedResourcePath ;
}