diff --git a/org.springframework.jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java b/org.springframework.jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java index aeed1659222..07e1ad5942a 100644 --- a/org.springframework.jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java +++ b/org.springframework.jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterUtils.java @@ -16,6 +16,7 @@ package org.springframework.jdbc.core.namedparam; +import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; import java.util.Iterator; @@ -69,15 +70,12 @@ public abstract class NamedParameterUtils { * @param sql the SQL statement * @return the parsed statement, represented as ParsedSql instance */ - public static ParsedSql parseSqlStatement(String sql) { + public static ParsedSql parseSqlStatement(final String sql) { Assert.notNull(sql, "SQL must not be null"); Set namedParameters = new HashSet(); String sqlToUse = sql; - if (sql.contains("\\:")) { - sqlToUse = sql.replace("\\:", ":"); - } - ParsedSql parsedSql = new ParsedSql(sqlToUse); + List parameterList = new ArrayList(); char[] statement = sql.toCharArray(); int namedParameterCount = 0; @@ -111,7 +109,7 @@ public abstract class NamedParameterUtils { if (j - i > 3) { parameter = sql.substring(i + 2, j); namedParameterCount = addNewNamedParameter(namedParameters, namedParameterCount, parameter); - totalParameterCount = addNamedParameter(parsedSql, totalParameterCount, escapes, i, j + 1, parameter); + totalParameterCount = addNamedParameter(parameterList, totalParameterCount, escapes, i, j + 1, parameter); } j++; } @@ -122,7 +120,7 @@ public abstract class NamedParameterUtils { if (j - i > 1) { parameter = sql.substring(i + 1, j); namedParameterCount = addNewNamedParameter(namedParameters, namedParameterCount, parameter); - totalParameterCount = addNamedParameter(parsedSql, totalParameterCount, escapes, i, j, parameter); + totalParameterCount = addNamedParameter(parameterList, totalParameterCount, escapes, i, j, parameter); } } i = j - 1; @@ -132,6 +130,7 @@ public abstract class NamedParameterUtils { int j = i + 1; if (j < statement.length && statement[j] == ':') { // this is an escaped : and should be skipped + sqlToUse = sqlToUse.substring(0, i - escapes) + sqlToUse.substring(i - escapes + 1); escapes++; i = i + 2; continue; @@ -144,20 +143,24 @@ public abstract class NamedParameterUtils { } i++; } + ParsedSql parsedSql = new ParsedSql(sqlToUse); + for (ParameterHolder ph : parameterList) { + parsedSql.addNamedParameter(ph.getParameterName(), ph.getStartIndex(), ph.getEndIndex()); + } parsedSql.setNamedParameterCount(namedParameterCount); parsedSql.setUnnamedParameterCount(unnamedParameterCount); parsedSql.setTotalParameterCount(totalParameterCount); return parsedSql; } - protected static int addNamedParameter(ParsedSql parsedSql, int totalParameterCount, int escapes, int i, int j, + private static int addNamedParameter(List parameterList, int totalParameterCount, int escapes, int i, int j, String parameter) { - parsedSql.addNamedParameter(parameter, i - escapes, j - escapes); + parameterList.add(new ParameterHolder(parameter, i - escapes, j - escapes)); totalParameterCount++; return totalParameterCount; } - protected static int addNewNamedParameter(Set namedParameters, int namedParameterCount, String parameter) { + private static int addNewNamedParameter(Set namedParameters, int namedParameterCount, String parameter) { if (!namedParameters.contains(parameter)) { namedParameters.add(parameter); namedParameterCount++; @@ -445,4 +448,28 @@ public abstract class NamedParameterUtils { return buildValueArray(parsedSql, new MapSqlParameterSource(paramMap), null); } + private static class ParameterHolder { + private String parameterName; + private int startIndex; + private int endIndex; + + public ParameterHolder(String parameterName, int startIndex, int endIndex) { + super(); + this.parameterName = parameterName; + this.startIndex = startIndex; + this.endIndex = endIndex; + } + + public String getParameterName() { + return parameterName; + } + + public int getStartIndex() { + return startIndex; + } + + public int getEndIndex() { + return endIndex; + } + } } diff --git a/org.springframework.jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterUtilsTests.java b/org.springframework.jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterUtilsTests.java index b7f1b7fd6e3..c5f14ac6f86 100644 --- a/org.springframework.jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterUtilsTests.java +++ b/org.springframework.jdbc/src/test/java/org/springframework/jdbc/core/namedparam/NamedParameterUtilsTests.java @@ -189,8 +189,8 @@ public class NamedParameterUtilsTests { */ @Test public void parseSqlStatementWithEscapedColon() throws Exception { - String expectedSql = "select foo from bar where baz < DATE(? 23:59:59) and baz = ?"; - String sql = "select foo from bar where baz < DATE(:p1 23\\:59\\:59) and baz = :p2"; + String expectedSql = "select '0\\:0' as a, foo from bar where baz < DATE(? 23:59:59) and baz = ?"; + String sql = "select '0\\:0' as a, foo from bar where baz < DATE(:p1 23\\:59\\:59) and baz = :p2"; ParsedSql parsedSql = NamedParameterUtils.parseSqlStatement(sql); assertEquals(2, parsedSql.getParameterNames().size());