diff --git a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java index 99bd8068de..503ef6e434 100644 --- a/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java +++ b/openid/src/main/java/org/springframework/security/openid/OpenIDAuthenticationFilter.java @@ -27,8 +27,10 @@ import org.springframework.util.StringUtils; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; +import java.io.UnsupportedEncodingException; import java.net.MalformedURLException; import java.net.URL; +import java.net.URLEncoder; import java.util.*; @@ -186,32 +188,37 @@ public class OpenIDAuthenticationFilter extends AbstractAuthenticationProcessing * @return The return_to URL. */ protected String buildReturnToUrl(HttpServletRequest request) { - StringBuffer sb = request.getRequestURL(); - - Iterator iterator = returnToUrlParameters.iterator(); - boolean isFirst = true; - - while (iterator.hasNext()) { - String name = iterator.next(); - // Assume for simplicity that there is only one value - String value = request.getParameter(name); - - if (value == null) { - continue; - } - - if (isFirst) { - sb.append("?"); - isFirst = false; - } - sb.append(name).append("=").append(value); - - if (iterator.hasNext()) { - sb.append("&"); + try { + StringBuffer sb = request.getRequestURL(); + + Iterator iterator = returnToUrlParameters.iterator(); + boolean isFirst = true; + + while (iterator.hasNext()) { + String name = iterator.next(); + // Assume for simplicity that there is only one value + String value = request.getParameter(name); + + if (value == null) { + continue; + } + + if (isFirst) { + sb.append("?"); + isFirst = false; + } + sb.append(URLEncoder.encode(name, "UTF-8")).append("=").append(URLEncoder.encode(value, "UTF-8")); + + if (iterator.hasNext()) { + sb.append("&"); + } } + return sb.toString(); + } catch(UnsupportedEncodingException e) { + Error err = new AssertionError("The Java platform guarantees UTF-8 support, but it seemingly is not present."); + err.initCause(e); + throw err; } - - return sb.toString(); } /** diff --git a/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationFilterTests.java b/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationFilterTests.java index 54efad432a..80adce08b3 100644 --- a/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationFilterTests.java +++ b/openid/src/test/java/org/springframework/security/openid/OpenIDAuthenticationFilterTests.java @@ -4,6 +4,9 @@ import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.any; import static org.mockito.Mockito.*; +import java.net.URI; +import java.util.Collections; + import javax.servlet.FilterChain; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; @@ -63,4 +66,35 @@ public class OpenIDAuthenticationFilterTests { // Filter chain shouldn't proceed verify(fc, never()).doFilter(any(HttpServletRequest.class), any(HttpServletResponse.class)); } + + /** + * Tests that the filter encodes any query parameters on the return_to URL. + */ + @Test + public void encodesUrlParameters() throws Exception { + // Arbitrary parameter name and value that will both need to be encoded: + String paramName = "foo&bar"; + String paramValue = "http://example.com/path?a=b&c=d"; + MockHttpServletRequest req = new MockHttpServletRequest("GET", REQUEST_PATH); + req.addParameter(paramName, paramValue); + filter.setReturnToUrlParameters(Collections.singleton(paramName)); + + URI returnTo = new URI(filter.buildReturnToUrl(req)); + String query = returnTo.getRawQuery(); + assertEquals(1, count(query, '=')); + assertEquals(0, count(query, '&')); + } + + /** + * Counts the number of occurrences of {@code c} in {@code s}. + */ + private static int count(String s, char c) { + int count = 0; + for(char ch : s.toCharArray()) { + if(c == ch) { + count += 1; + } + } + return count; + } }