diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java index 3e18a956cb8..98fcc67882e 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/PreparedStatementCreatorFactory.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,6 @@ import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Types; import java.util.Arrays; -import java.util.Collection; import java.util.Collections; import java.util.HashSet; import java.util.LinkedList; @@ -32,6 +31,7 @@ import java.util.Set; import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; /** * Helper class that efficiently creates multiple {@link PreparedStatementCreator} @@ -268,13 +268,19 @@ public class PreparedStatementCreatorFactory { } declaredParameter = declaredParameters.get(i); } - if (in instanceof Collection && declaredParameter.getSqlType() != Types.ARRAY) { - Collection entries = (Collection) in; + if (in != null && in.getClass().isArray()) { + in = Arrays.asList(ObjectUtils.toObjectArray(in)); + } + if (in instanceof Iterable && declaredParameter.getSqlType() != Types.ARRAY) { + Iterable entries = (Iterable) in; for (Object entry : entries) { - if (entry instanceof Object[]) { - Object[] valueArray = ((Object[])entry); - for (Object argValue : valueArray) { - StatementCreatorUtils.setParameterValue(ps, sqlColIndx++, declaredParameter, argValue); + if (entry != null && entry.getClass().isArray()) { + entry = Arrays.asList(ObjectUtils.toObjectArray(entry)); + } + if (entry instanceof Iterable) { + Iterable values = (Iterable) entry; + for (Object value : values) { + StatementCreatorUtils.setParameterValue(ps, sqlColIndx++, declaredParameter, value); } } else { diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java index 3fbc4c6b0f5..b45b8ffd855 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2018 the original author or authors. + * Copyright 2002-2019 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ package org.springframework.jdbc.core.namedparam; import java.util.ArrayList; -import java.util.Collection; +import java.util.Arrays; import java.util.HashSet; import java.util.Iterator; import java.util.List; @@ -29,6 +29,7 @@ import org.springframework.jdbc.core.SqlParameter; import org.springframework.jdbc.core.SqlParameterValue; import org.springframework.lang.Nullable; import org.springframework.util.Assert; +import org.springframework.util.ObjectUtils; /** * Helper methods for named parameter parsing. @@ -284,8 +285,11 @@ public abstract class NamedParameterUtils { if (value instanceof SqlParameterValue) { value = ((SqlParameterValue) value).getValue(); } - if (value instanceof Collection) { - Iterator entryIter = ((Collection) value).iterator(); + if (value != null && value.getClass().isArray()) { + value = Arrays.asList(ObjectUtils.toObjectArray(value)); + } + if (value instanceof Iterable) { + Iterator entryIter = ((Iterable) value).iterator(); int k = 0; while (entryIter.hasNext()) { if (k > 0) { diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplateTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplateTests.java index ce2eeaa1452..d76a4c85945 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplateTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplateTests.java @@ -475,11 +475,12 @@ public class NamedParameterJdbcTemplateTests { @Test public void testBatchUpdateWithInClause() throws Exception { @SuppressWarnings("unchecked") - Map[] parameters = new Map[2]; + Map[] parameters = new Map[3]; parameters[0] = Collections.singletonMap("ids", Arrays.asList(1, 2)); - parameters[1] = Collections.singletonMap("ids", Arrays.asList(3, 4)); + parameters[1] = Collections.singletonMap("ids", new Integer[] {3, 4}); + parameters[2] = Collections.singletonMap("ids", (Iterable) () -> Arrays.asList(5, 6).iterator()); - final int[] rowsAffected = new int[] {1, 2}; + final int[] rowsAffected = new int[] {1, 2, 3}; given(preparedStatement.executeBatch()).willReturn(rowsAffected); given(connection.getMetaData()).willReturn(databaseMetaData); @@ -491,7 +492,7 @@ public class NamedParameterJdbcTemplateTests { parameters ); - assertThat(actualRowsAffected.length).as("executed 2 updates").isEqualTo(2); + assertThat(actualRowsAffected.length).as("executed 3 updates").isEqualTo(3); InOrder inOrder = inOrder(preparedStatement); @@ -503,6 +504,10 @@ public class NamedParameterJdbcTemplateTests { inOrder.verify(preparedStatement).setObject(2, 4); inOrder.verify(preparedStatement).addBatch(); + inOrder.verify(preparedStatement).setObject(1, 5); + inOrder.verify(preparedStatement).setObject(2, 6); + inOrder.verify(preparedStatement).addBatch(); + inOrder.verify(preparedStatement, atLeastOnce()).close(); verify(connection, atLeastOnce()).close(); }