From 66292cd7a1697a8d99b3dd3eaff4706e4beab558 Mon Sep 17 00:00:00 2001 From: Juergen Hoeller Date: Thu, 12 Nov 2020 14:05:31 +0100 Subject: [PATCH] Individually apply the SQL type from each SqlParameterSource argument Closes gh-26071 --- .../core/PreparedStatementCreatorFactory.java | 4 +--- .../jdbc/core/namedparam/NamedParameterUtils.java | 4 ++-- .../core/namedparam/SqlParameterSourceUtils.java | 8 ++------ .../NamedParameterJdbcTemplateTests.java | 15 +++++++++------ 4 files changed, 14 insertions(+), 17 deletions(-) diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java index f07ee04cd0a..e6083f8be0f 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java @@ -30,7 +30,6 @@ import java.util.Set; import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; /** * Helper class that efficiently creates multiple {@link PreparedStatementCreator} @@ -200,9 +199,8 @@ public class PreparedStatementCreatorFactory { public PreparedStatementCreatorImpl(String actualSql, List parameters) { this.actualSql = actualSql; - Assert.notNull(parameters, "Parameters List must not be null"); this.parameters = parameters; - if (this.parameters.size() != declaredParameters.size()) { + if (parameters.size() != declaredParameters.size()) { // Account for named parameters being used multiple times Set names = new HashSet<>(); for (int i = 0; i < parameters.size(); i++) { diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java index 23feefd6738..4d4c414ea7a 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java @@ -345,9 +345,9 @@ public abstract class NamedParameterUtils { for (int i = 0; i < paramNames.size(); i++) { String paramName = paramNames.get(i); try { - Object value = paramSource.getValue(paramName); SqlParameter param = findParameter(declaredParams, paramName, i); - paramArray[i] = (param != null ? new SqlParameterValue(param, value) : value); + paramArray[i] = (param != null ? new SqlParameterValue(param, paramSource.getValue(paramName)) : + SqlParameterSourceUtils.getTypedValue(paramSource, paramName)); } catch (IllegalArgumentException ex) { throw new InvalidDataAccessApiUsageException( diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/SqlParameterSourceUtils.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/SqlParameterSourceUtils.java index 4ae12a9533a..e2bd60e05ff 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/SqlParameterSourceUtils.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/SqlParameterSourceUtils.java @@ -92,17 +92,13 @@ public abstract class SqlParameterSourceUtils { * @param source the source of parameter values and type information * @param parameterName the name of the parameter * @return the value object + * @see SqlParameterValue */ @Nullable public static Object getTypedValue(SqlParameterSource source, String parameterName) { int sqlType = source.getSqlType(parameterName); if (sqlType != SqlParameterSource.TYPE_UNKNOWN) { - if (source.getTypeName(parameterName) != null) { - return new SqlParameterValue(sqlType, source.getTypeName(parameterName), source.getValue(parameterName)); - } - else { - return new SqlParameterValue(sqlType, source.getValue(parameterName)); - } + return new SqlParameterValue(sqlType, source.getTypeName(parameterName), source.getValue(parameterName)); } else { return source.getValue(parameterName); diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplateTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplateTests.java index d9dc25f77af..31fa105d005 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplateTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplateTests.java @@ -561,10 +561,11 @@ public class NamedParameterJdbcTemplateTests { @Test public void testBatchUpdateWithSqlParameterSourcePlusTypeInfo() throws Exception { - SqlParameterSource[] ids = new SqlParameterSource[2]; - ids[0] = new MapSqlParameterSource().addValue("id", 100, Types.NUMERIC); - ids[1] = new MapSqlParameterSource().addValue("id", 200, Types.NUMERIC); - final int[] rowsAffected = new int[] {1, 2}; + SqlParameterSource[] ids = new SqlParameterSource[3]; + ids[0] = new MapSqlParameterSource().addValue("id", null, Types.NULL); + ids[1] = new MapSqlParameterSource().addValue("id", 100, Types.NUMERIC); + ids[2] = new MapSqlParameterSource().addValue("id", 200, Types.NUMERIC); + final int[] rowsAffected = new int[] {1, 2, 3}; given(preparedStatement.executeBatch()).willReturn(rowsAffected); given(connection.getMetaData()).willReturn(databaseMetaData); @@ -572,13 +573,15 @@ public class NamedParameterJdbcTemplateTests { int[] actualRowsAffected = namedParameterTemplate.batchUpdate( "UPDATE NOSUCHTABLE SET DATE_DISPATCHED = SYSDATE WHERE ID = :id", ids); - assertThat(actualRowsAffected.length == 2).as("executed 2 updates").isTrue(); + assertThat(actualRowsAffected.length == 3).as("executed 3 updates").isTrue(); assertThat(actualRowsAffected[0]).isEqualTo(rowsAffected[0]); assertThat(actualRowsAffected[1]).isEqualTo(rowsAffected[1]); + assertThat(actualRowsAffected[2]).isEqualTo(rowsAffected[2]); verify(connection).prepareStatement("UPDATE NOSUCHTABLE SET DATE_DISPATCHED = SYSDATE WHERE ID = ?"); + verify(preparedStatement).setNull(1, Types.NULL); verify(preparedStatement).setObject(1, 100, Types.NUMERIC); verify(preparedStatement).setObject(1, 200, Types.NUMERIC); - verify(preparedStatement, times(2)).addBatch(); + verify(preparedStatement, times(3)).addBatch(); verify(preparedStatement, atLeastOnce()).close(); verify(connection, atLeastOnce()).close(); }