diff --git a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java
index 99bd8068de..503ef6e434 100644
--- a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java
+++ b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java
@@ -27,8 +27,10 @@ import org.springframework.util.StringUtils;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
+import java.io.UnsupportedEncodingException;
import java.net.MalformedURLException;
import java.net.URL;
+import java.net.URLEncoder;
import java.util.*;
@@ -186,32 +188,37 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing
* @return The return_to URL.
*/
protected String buildReturnToUrl(HttpServletRequest request) {
- StringBuffer sb = request.getRequestURL();
-
- Iterator iterator = returnToUrlParameters.iterator();
- boolean isFirst = true;
-
- while (iterator.hasNext()) {
- String name = iterator.next();
- // Assume for simplicity that there is only one value
- String value = request.getParameter(name);
-
- if (value == null) {
- continue;
- }
-
- if (isFirst) {
- sb.append("?");
- isFirst = false;
- }
- sb.append(name).append("=").append(value);
-
- if (iterator.hasNext()) {
- sb.append("&");
+ try {
+ StringBuffer sb = request.getRequestURL();
+
+ Iterator iterator = returnToUrlParameters.iterator();
+ boolean isFirst = true;
+
+ while (iterator.hasNext()) {
+ String name = iterator.next();
+ // Assume for simplicity that there is only one value
+ String value = request.getParameter(name);
+
+ if (value == null) {
+ continue;
+ }
+
+ if (isFirst) {
+ sb.append("?");
+ isFirst = false;
+ }
+ sb.append(URLEncoder.encode(name, "UTF-8")).append("=").append(URLEncoder.encode(value, "UTF-8"));
+
+ if (iterator.hasNext()) {
+ sb.append("&");
+ }
}
+ return sb.toString();
+ } catch(UnsupportedEncodingException e) {
+ Error err = new AssertionError("The Java platform guarantees UTF-8 support, but it seemingly is not present.");
+ err.initCause(e);
+ throw err;
}
-
- return sb.toString();
}
/**
diff --git a/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationFilterTests.java b/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationFilterTests.java
index 54efad432a..80adce08b3 100644
--- a/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationFilterTests.java
+++ b/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationFilterTests.java
@@ -4,6 +4,9 @@ import static org.junit.Assert.assertEquals;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.*;
+import java.net.URI;
+import java.util.Collections;
+
import javax.servlet.FilterChain;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
@@ -63,4 +66,35 @@ public class OpenIDAuthenticationFilterTests {
// Filter chain shouldn't proceed
verify(fc, never()).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class));
}
+
+ /**
+ * Tests that the filter encodes any query parameters on the return_to URL.
+ */
+ @Test
+ public void encodesUrlParameters() throws Exception {
+ // Arbitrary parameter name and value that will both need to be encoded:
+ String paramName = "foo&bar";
+ String paramValue = "http://example.com/path?a=b&c=d";
+ MockHttpServletRequest req = new MockHttpServletRequest("GET", REQUEST_PATH);
+ req.addParameter(paramName, paramValue);
+ filter.setReturnToUrlParameters(Collections.singleton(paramName));
+
+ URI returnTo = new URI(filter.buildReturnToUrl(req));
+ String query = returnTo.getRawQuery();
+ assertEquals(1, count(query, '='));
+ assertEquals(0, count(query, '&'));
+ }
+
+ /**
+ * Counts the number of occurrences of {@code c} in {@code s}.
+ */
+ private static int count(String s, char c) {
+ int count = 0;
+ for(char ch : s.toCharArray()) {
+ if(c == ch) {
+ count += 1;
+ }
+ }
+ return count;
+ }
}