Refine StringUtils#uriDecode

Refine the StringUtils#uriDecode method in the following ways:

- Use a StringBuilder instead of ByteArrayOutputStream, and only decode
  %-encoded sequences.
- Use HexFormat.fromHexDigits to decode hex sequences.
- Decode to a byte array that is only allocated if encoded sequences are
  encountered.

Signed-off-by: Patrick Strawderman <pstrawderman@netflix.com>
See gh-34673
This commit is contained in:
Patrick Strawderman 2025-03-28 11:54:16 -07:00 committed by Sébastien Deleuze
parent 0b92a51650
commit 7f1bc20688
2 changed files with 36 additions and 17 deletions

View File

@ -16,7 +16,6 @@
package org.springframework.util;
import java.io.ByteArrayOutputStream;
import java.nio.charset.Charset;
import java.util.ArrayDeque;
import java.util.ArrayList;
@ -25,6 +24,7 @@ import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.Enumeration;
import java.util.HexFormat;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
@ -816,38 +816,50 @@ public abstract class StringUtils {
* @see java.net.URLDecoder#decode(String, String)
*/
public static String uriDecode(String source, Charset charset) {
Assert.notNull(charset, "Charset must not be null");
int length = source.length();
if (length == 0) {
return source;
}
Assert.notNull(charset, "Charset must not be null");
ByteArrayOutputStream baos = new ByteArrayOutputStream(length);
StringBuilder output = new StringBuilder(length);
boolean changed = false;
for (int i = 0; i < length; i++) {
int ch = source.charAt(i);
byte[] bytes = null;
int i = 0;
while (i < length) {
char ch = source.charAt(i);
if (ch == '%') {
if (i + 2 < length) {
char hex1 = source.charAt(i + 1);
char hex2 = source.charAt(i + 2);
int u = Character.digit(hex1, 16);
int l = Character.digit(hex2, 16);
if (u == -1 || l == -1) {
throw new IllegalArgumentException("Invalid encoded sequence \"" + source.substring(i) + "\"");
try {
if (bytes == null) {
bytes = new byte[(length - i) / 3];
}
baos.write((char) ((u << 4) + l));
i += 2;
int pos = 0;
while (i + 2 < length && ch == '%') {
bytes[pos++] = (byte) HexFormat.fromHexDigits(source, i + 1, i + 3);
i += 3;
if (i < length) {
ch = source.charAt(i);
}
}
if (i < length && ch == '%') {
throw new IllegalArgumentException("Incomplete trailing escape (%) pattern");
}
output.append(new String(bytes, 0, pos, charset));
changed = true;
}
else {
catch (NumberFormatException ex) {
throw new IllegalArgumentException("Invalid encoded sequence \"" + source.substring(i) + "\"");
}
}
else {
baos.write(ch);
output.append(ch);
i++;
}
}
return (changed ? StreamUtils.copyToString(baos, charset) : source);
return (changed ? output.toString() : source);
}
/**

View File

@ -107,12 +107,19 @@ class UriUtilsTests {
assertThat(UriUtils.decode("T%C5%8Dky%C5%8D", CHARSET)).as("Invalid encoded result").isEqualTo("T\u014dky\u014d");
assertThat(UriUtils.decode("/Z%C3%BCrich", CHARSET)).as("Invalid encoded result").isEqualTo("/Z\u00fcrich");
assertThat(UriUtils.decode("T\u014dky\u014d", CHARSET)).as("Invalid encoded result").isEqualTo("T\u014dky\u014d");
assertThat(UriUtils.decode("%20\u2019", CHARSET)).as("Invalid encoded result").isEqualTo(" \u2019");
}
@Test
void decodeInvalidSequence() {
assertThatIllegalArgumentException().isThrownBy(() ->
UriUtils.decode("foo%2", CHARSET));
assertThatIllegalArgumentException().isThrownBy(() ->
UriUtils.decode("foo%", CHARSET));
assertThatIllegalArgumentException().isThrownBy(() ->
UriUtils.decode("%", CHARSET));
assertThatIllegalArgumentException().isThrownBy(() ->
UriUtils.decode("%zz", CHARSET));
}
@Test