Merge class-level and method-level @Sql declarations

See gh-1835
This commit is contained in:
asympro 2018-05-25 16:33:41 +03:00 committed by Sam Brannen
parent b0939a8af0
commit d77b715d38
5 changed files with 137 additions and 20 deletions

View File

@ -31,7 +31,8 @@ import org.springframework.core.annotation.AliasFor;
* SQL {@link #scripts} and {@link #statements} to be executed against a given * SQL {@link #scripts} and {@link #statements} to be executed against a given
* database during integration tests. * database during integration tests.
* *
* <p>Method-level declarations override class-level declarations. * <p>Method-level declarations override class-level declarations by default.
* This behaviour can be adjusted via {@link MergeMode}
* *
* <p>Script execution is performed by the {@link SqlScriptsTestExecutionListener}, * <p>Script execution is performed by the {@link SqlScriptsTestExecutionListener},
* which is enabled by default. * which is enabled by default.
@ -146,6 +147,13 @@ public @interface Sql {
*/ */
SqlConfig config() default @SqlConfig; SqlConfig config() default @SqlConfig;
/**
* Indicates whether this annotation should be merged with upper-level annotations
* or override them.
* <p>Defaults to {@link MergeMode#OVERRIDE}.
*/
MergeMode mergeMode() default MergeMode.OVERRIDE;
/** /**
* Enumeration of <em>phases</em> that dictate when SQL scripts are executed. * Enumeration of <em>phases</em> that dictate when SQL scripts are executed.
@ -165,4 +173,23 @@ public @interface Sql {
AFTER_TEST_METHOD AFTER_TEST_METHOD
} }
/**
* Enumeration of <em>modes</em> that dictate whether or not
* declared SQL {@link #scripts} and {@link #statements} are merged
* with the upper-level annotations.
*/
enum MergeMode {
/**
* Indicates that locally declared SQL {@link #scripts} and {@link #statements}
* should override the upper-level (e.g. Class-level) annotations.
*/
OVERRIDE,
/**
* Indicates that locally declared SQL {@link #scripts} and {@link #statements}
* should be merged the upper-level (e.g. Class-level) annotations.
*/
MERGE
}
} }

View File

@ -16,14 +16,17 @@
package org.springframework.test.context.jdbc; package org.springframework.test.context.jdbc;
import java.lang.reflect.AnnotatedElement;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors;
import javax.sql.DataSource; import javax.sql.DataSource;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
import org.jetbrains.annotations.NotNull;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.io.ByteArrayResource; import org.springframework.core.io.ByteArrayResource;
@ -126,19 +129,35 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
* {@link TestContext} and {@link ExecutionPhase}. * {@link TestContext} and {@link ExecutionPhase}.
*/ */
private void executeSqlScripts(TestContext testContext, ExecutionPhase executionPhase) throws Exception { private void executeSqlScripts(TestContext testContext, ExecutionPhase executionPhase) throws Exception {
boolean classLevel = false; Set<Sql> methodLevelSqls = getScriptsFromElement(testContext.getTestMethod());
List<Sql> methodLevelOverrides = methodLevelSqls.stream()
Set<Sql> sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations( .filter(s -> s.executionPhase() == executionPhase)
testContext.getTestMethod(), Sql.class, SqlGroup.class); .filter(s -> s.mergeMode() == Sql.MergeMode.OVERRIDE)
if (sqlAnnotations.isEmpty()) { .collect(Collectors.toList());
sqlAnnotations = AnnotatedElementUtils.getMergedRepeatableAnnotations( if (methodLevelOverrides.isEmpty()) {
testContext.getTestClass(), Sql.class, SqlGroup.class); executeScripts(getScriptsFromElement(testContext.getTestClass()), testContext, executionPhase, true);
if (!sqlAnnotations.isEmpty()) { executeScripts(methodLevelSqls, testContext, executionPhase, false);
classLevel = true; } else {
} executeScripts(methodLevelOverrides, testContext, executionPhase, false);
} }
}
for (Sql sql : sqlAnnotations) { /**
* Get SQL scripts configured via {@link Sql @Sql} for the supplied
* {@link AnnotatedElement}.
*/
private Set<Sql> getScriptsFromElement(AnnotatedElement annotatedElement) throws Exception {
return AnnotatedElementUtils.getMergedRepeatableAnnotations(annotatedElement, Sql.class, SqlGroup.class);
}
/**
* Execute given {@link Sql @Sql} scripts.
* {@link AnnotatedElement}.
*/
private void executeScripts(Iterable<Sql> scripts, TestContext testContext, ExecutionPhase executionPhase,
boolean classLevel)
throws Exception {
for (Sql sql : scripts) {
executeSqlScripts(sql, executionPhase, testContext, classLevel); executeSqlScripts(sql, executionPhase, testContext, classLevel);
} }
} }
@ -166,14 +185,7 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
mergedSqlConfig, executionPhase, testContext)); mergedSqlConfig, executionPhase, testContext));
} }
final ResourceDatabasePopulator populator = new ResourceDatabasePopulator(); final ResourceDatabasePopulator populator = configurePopulator(mergedSqlConfig);
populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
populator.setSeparator(mergedSqlConfig.getSeparator());
populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix());
populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);
String[] scripts = getScripts(sql, testContext, classLevel); String[] scripts = getScripts(sql, testContext, classLevel);
scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts); scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts);
@ -232,6 +244,19 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen
} }
} }
@NotNull
private ResourceDatabasePopulator configurePopulator(MergedSqlConfig mergedSqlConfig) {
final ResourceDatabasePopulator populator = new ResourceDatabasePopulator();
populator.setSqlScriptEncoding(mergedSqlConfig.getEncoding());
populator.setSeparator(mergedSqlConfig.getSeparator());
populator.setCommentPrefix(mergedSqlConfig.getCommentPrefix());
populator.setBlockCommentStartDelimiter(mergedSqlConfig.getBlockCommentStartDelimiter());
populator.setBlockCommentEndDelimiter(mergedSqlConfig.getBlockCommentEndDelimiter());
populator.setContinueOnError(mergedSqlConfig.getErrorMode() == ErrorMode.CONTINUE_ON_ERROR);
populator.setIgnoreFailedDrops(mergedSqlConfig.getErrorMode() == ErrorMode.IGNORE_FAILED_DROPS);
return populator;
}
@Nullable @Nullable
private DataSource getDataSourceFromTransactionManager(PlatformTransactionManager transactionManager) { private DataSource getDataSourceFromTransactionManager(PlatformTransactionManager transactionManager) {
try { try {

View File

@ -25,6 +25,7 @@ import org.junit.runners.MethodSorters;
import org.springframework.test.annotation.DirtiesContext; import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests; import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
import org.springframework.test.jdbc.JdbcTestUtils;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
@ -58,6 +59,10 @@ public class RepeatableSqlAnnotationSqlScriptsTests extends AbstractTransactiona
assertNumUsers(2); assertNumUsers(2);
} }
protected int countRowsInTable(String tableName) {
return JdbcTestUtils.countRowsInTable(this.jdbcTemplate, tableName);
}
protected void assertNumUsers(int expected) { protected void assertNumUsers(int expected) {
assertThat(countRowsInTable("user")).as("Number of rows in the 'user' table.").isEqualTo(expected); assertThat(countRowsInTable("user")).as("Number of rows in the 'user' table.").isEqualTo(expected);
} }

View File

@ -0,0 +1,30 @@
package org.springframework.test.context.jdbc;
import org.junit.Test;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
import static org.junit.Assert.assertEquals;
/**
* Test to verify method level merge of @Sql annotations.
*
* @author Dmitry Semukhin
*/
@ContextConfiguration(classes = EmptyDatabaseConfig.class)
@Sql(value = {"schema.sql", "data-add-catbert.sql"})
@DirtiesContext
public class SqlMethodMergeTest extends AbstractTransactionalJUnit4SpringContextTests {
@Test
@Sql(value = "data-add-dogbert.sql", mergeMode = Sql.MergeMode.MERGE)
public void testMerge() {
assertNumUsers(2);
}
protected void assertNumUsers(int expected) {
assertEquals("Number of rows in the 'user' table.", expected, countRowsInTable("user"));
}
}

View File

@ -0,0 +1,30 @@
package org.springframework.test.context.jdbc;
import org.junit.Test;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.test.context.ContextConfiguration;
import org.springframework.test.context.junit4.AbstractTransactionalJUnit4SpringContextTests;
import static org.junit.Assert.assertEquals;
/**
* Test to verify method level override of @Sql annotations.
*
* @author Dmitry Semukhin
*/
@ContextConfiguration(classes = EmptyDatabaseConfig.class)
@Sql(value = {"schema.sql", "data-add-catbert.sql"})
@DirtiesContext
public class SqlMethodOverrideTest extends AbstractTransactionalJUnit4SpringContextTests {
@Test
@Sql(value = {"schema.sql", "data.sql", "data-add-dogbert.sql", "data-add-catbert.sql"}, mergeMode = Sql.MergeMode.OVERRIDE)
public void testMerge() {
assertNumUsers(3);
}
protected void assertNumUsers(int expected) {
assertEquals("Number of rows in the 'user' table.", expected, countRowsInTable("user"));
}
}