Discard further rows once maxRows has been reached

See https://github.com/spring-projects/spring-framework/issues/34666#issuecomment-2773151317

Signed-off-by: Yanming Zhou <zhouyanming@gmail.com>
This commit is contained in:
Yanming Zhou 2025-04-03 10:30:46 +08:00 committed by Juergen Hoeller
parent d957f8bb5d
commit 88257f7dfd
3 changed files with 87 additions and 17 deletions

View File

@ -102,6 +102,7 @@ import org.springframework.util.StringUtils;
* @author Rod Johnson
* @author Juergen Hoeller
* @author Thomas Risberg
* @author Yanming Zhou
* @since May 3, 2001
* @see JdbcOperations
* @see PreparedStatementCreator
@ -493,12 +494,12 @@ public class JdbcTemplate extends JdbcAccessor implements JdbcOperations {
@Override
public void query(String sql, RowCallbackHandler rch) throws DataAccessException {
query(sql, new RowCallbackHandlerResultSetExtractor(rch));
query(sql, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows));
}
@Override
public <T> List<T> query(String sql, RowMapper<T> rowMapper) throws DataAccessException {
return result(query(sql, new RowMapperResultSetExtractor<>(rowMapper)));
return result(query(sql, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
}
@Override
@ -508,7 +509,7 @@ public class JdbcTemplate extends JdbcAccessor implements JdbcOperations {
public Stream<T> doInStatement(Statement stmt) throws SQLException {
ResultSet rs = stmt.executeQuery(sql);
Connection con = stmt.getConnection();
return new ResultSetSpliterator<>(rs, rowMapper).stream().onClose(() -> {
return new ResultSetSpliterator<>(rs, rowMapper, JdbcTemplate.this.maxRows).stream().onClose(() -> {
JdbcUtils.closeResultSet(rs);
JdbcUtils.closeStatement(stmt);
DataSourceUtils.releaseConnection(con, getDataSource());
@ -773,12 +774,12 @@ public class JdbcTemplate extends JdbcAccessor implements JdbcOperations {
@Override
public void query(PreparedStatementCreator psc, RowCallbackHandler rch) throws DataAccessException {
query(psc, new RowCallbackHandlerResultSetExtractor(rch));
query(psc, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows));
}
@Override
public void query(String sql, @Nullable PreparedStatementSetter pss, RowCallbackHandler rch) throws DataAccessException {
query(sql, pss, new RowCallbackHandlerResultSetExtractor(rch));
query(sql, pss, new RowCallbackHandlerResultSetExtractor(rch, this.maxRows));
}
@Override
@ -799,28 +800,28 @@ public class JdbcTemplate extends JdbcAccessor implements JdbcOperations {
@Override
public <T> List<T> query(PreparedStatementCreator psc, RowMapper<T> rowMapper) throws DataAccessException {
return result(query(psc, new RowMapperResultSetExtractor<>(rowMapper)));
return result(query(psc, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
}
@Override
public <T> List<T> query(String sql, @Nullable PreparedStatementSetter pss, RowMapper<T> rowMapper) throws DataAccessException {
return result(query(sql, pss, new RowMapperResultSetExtractor<>(rowMapper)));
return result(query(sql, pss, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
}
@Override
public <T> List<T> query(String sql, @Nullable Object @Nullable [] args, int[] argTypes, RowMapper<T> rowMapper) throws DataAccessException {
return result(query(sql, args, argTypes, new RowMapperResultSetExtractor<>(rowMapper)));
return result(query(sql, args, argTypes, new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
}
@Deprecated(since = "5.3")
@Override
public <T> List<T> query(String sql, @Nullable Object @Nullable [] args, RowMapper<T> rowMapper) throws DataAccessException {
return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper)));
return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
}
@Override
public <T> List<T> query(String sql, RowMapper<T> rowMapper, @Nullable Object @Nullable ... args) throws DataAccessException {
return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper)));
return result(query(sql, newArgPreparedStatementSetter(args), new RowMapperResultSetExtractor<>(rowMapper, 0, this.maxRows)));
}
/**
@ -845,7 +846,7 @@ public class JdbcTemplate extends JdbcAccessor implements JdbcOperations {
}
ResultSet rs = ps.executeQuery();
Connection con = ps.getConnection();
return new ResultSetSpliterator<>(rs, rowMapper).stream().onClose(() -> {
return new ResultSetSpliterator<>(rs, rowMapper, this.maxRows).stream().onClose(() -> {
JdbcUtils.closeResultSet(rs);
if (pss instanceof ParameterDisposer parameterDisposer) {
parameterDisposer.cleanupParameters();
@ -1364,7 +1365,7 @@ public class JdbcTemplate extends JdbcAccessor implements JdbcOperations {
}
else if (param.getRowCallbackHandler() != null) {
RowCallbackHandler rch = param.getRowCallbackHandler();
(new RowCallbackHandlerResultSetExtractor(rch)).extractData(rs);
(new RowCallbackHandlerResultSetExtractor(rch, -1)).extractData(rs);
return Collections.singletonMap(param.getName(),
"ResultSet returned from stored procedure was processed");
}
@ -1747,13 +1748,17 @@ public class JdbcTemplate extends JdbcAccessor implements JdbcOperations {
private final RowCallbackHandler rch;
public RowCallbackHandlerResultSetExtractor(RowCallbackHandler rch) {
private final int maxRows;
public RowCallbackHandlerResultSetExtractor(RowCallbackHandler rch, int maxRows) {
this.rch = rch;
this.maxRows = maxRows;
}
@Override
public @Nullable Object extractData(ResultSet rs) throws SQLException {
while (rs.next()) {
int processed = 0;
while (rs.next() && (this.maxRows == -1 || (processed++) < this.maxRows)) {
this.rch.processRow(rs);
}
return null;
@ -1771,17 +1776,20 @@ public class JdbcTemplate extends JdbcAccessor implements JdbcOperations {
private final RowMapper<T> rowMapper;
private final int maxRows;
private int rowNum = 0;
public ResultSetSpliterator(ResultSet rs, RowMapper<T> rowMapper) {
public ResultSetSpliterator(ResultSet rs, RowMapper<T> rowMapper, int maxRows) {
this.rs = rs;
this.rowMapper = rowMapper;
this.maxRows = maxRows;
}
@Override
public boolean tryAdvance(Consumer<? super T> action) {
try {
if (this.rs.next()) {
if (this.rs.next() && (this.maxRows == -1 || this.rowNum < this.maxRows)) {
action.accept(this.rowMapper.mapRow(this.rs, this.rowNum++));
return true;
}

View File

@ -52,6 +52,7 @@ import org.springframework.util.Assert;
* you can have executable query objects (containing row-mapping logic) there.
*
* @author Juergen Hoeller
* @author Yanming Zhou
* @since 1.0.2
* @param <T> the result element type
* @see RowMapper
@ -64,6 +65,8 @@ public class RowMapperResultSetExtractor<T> implements ResultSetExtractor<List<T
private final int rowsExpected;
private final int maxRows;
/**
* Create a new RowMapperResultSetExtractor.
@ -80,9 +83,21 @@ public class RowMapperResultSetExtractor<T> implements ResultSetExtractor<List<T
* (just used for optimized collection handling)
*/
public RowMapperResultSetExtractor(RowMapper<T> rowMapper, int rowsExpected) {
this(rowMapper, rowsExpected, -1);
}
/**
* Create a new RowMapperResultSetExtractor.
* @param rowMapper the RowMapper which creates an object for each row
* @param rowsExpected the number of expected rows
* (just used for optimized collection handling)
* @param maxRows the number of max rows
*/
public RowMapperResultSetExtractor(RowMapper<T> rowMapper, int rowsExpected, int maxRows) {
Assert.notNull(rowMapper, "RowMapper must not be null");
this.rowMapper = rowMapper;
this.rowsExpected = rowsExpected;
this.maxRows = maxRows;
}
@ -90,7 +105,7 @@ public class RowMapperResultSetExtractor<T> implements ResultSetExtractor<List<T
public List<T> extractData(ResultSet rs) throws SQLException {
List<T> results = (this.rowsExpected > 0 ? new ArrayList<>(this.rowsExpected) : new ArrayList<>());
int rowNum = 0;
while (rs.next()) {
while (rs.next() && (this.maxRows == -1 || rowNum < this.maxRows)) {
results.add(this.rowMapper.mapRow(rs, rowNum++));
}
return results;

View File

@ -32,7 +32,9 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Consumer;
import java.util.stream.Stream;
import javax.sql.DataSource;
@ -77,6 +79,7 @@ import static org.mockito.Mockito.verify;
* @author Thomas Risberg
* @author Juergen Hoeller
* @author Phillip Webb
* @author Yanming Zhou
*/
class JdbcTemplateTests {
@ -1236,6 +1239,50 @@ class JdbcTemplateTests {
Collections.singletonMap("someId", 456));
}
@Test
void testSkipFurtherRowsOnceMaxRowsHasBeenReachedForRowMapper() throws Exception {
testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) ->
template.query(sql, (rs, rowNum) -> rs.getString(1)));
}
@Test
void testDiscardFurtherRowsOnceMaxRowsHasBeenReachedForRowCallbackHandler() throws Exception {
testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) -> {
List<String> list = new ArrayList<>();
template.query(sql, (RowCallbackHandler) rs -> list.add(rs.getString(1)));
return list;
});
}
@Test
void testDiscardFurtherRowsOnceMaxRowsHasBeenReachedForStream() throws Exception {
testDiscardFurtherRowsOnceMaxRowsHasBeenReached((template, sql) -> {
try (Stream<String> stream = template.queryForStream(sql, (rs, rowNum) -> rs.getString(1))) {
return stream.toList();
}
});
}
private void testDiscardFurtherRowsOnceMaxRowsHasBeenReached(BiFunction<JdbcTemplate,String,List<String>> function) throws Exception {
String sql = "SELECT FORENAME FROM CUSTMR";
String[] results = {"rod", "gary", " portia"};
int maxRows = 2;
given(this.resultSet.next()).willReturn(true, true, true, false);
given(this.resultSet.getString(1)).willReturn(results[0], results[1], results[2]);
given(this.connection.createStatement()).willReturn(this.preparedStatement);
JdbcTemplate template = new JdbcTemplate();
template.setDataSource(this.dataSource);
template.setMaxRows(maxRows);
assertThat(function.apply(template, sql)).as("same length").hasSize(maxRows);
verify(this.resultSet).close();
verify(this.preparedStatement).close();
verify(this.connection).close();
}
private void mockDatabaseMetaData(boolean supportsBatchUpdates) throws SQLException {
DatabaseMetaData databaseMetaData = mock();
given(databaseMetaData.getDatabaseProductName()).willReturn("MySQL");