Adapt to HtmlUnit 2.25 getCharset() return type at runtime

Issue: SPR-15319
This commit is contained in:
Juergen Hoeller 2017-03-07 10:33:26 +01:00
parent 015e00b5dd
commit 9de97614a0
1 changed files with 36 additions and 25 deletions

View File

@ -1,5 +1,5 @@
/*
* Copyright 2002-2016 the original author or authors.
* Copyright 2002-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -17,8 +17,10 @@
package org.springframework.test.web.servlet.htmlunit;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.Method;
import java.net.URL;
import java.net.URLDecoder;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Enumeration;
import java.util.List;
@ -50,7 +52,9 @@ import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilde
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.test.web.servlet.request.RequestPostProcessor;
import org.springframework.util.Assert;
import org.springframework.util.ClassUtils;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.util.UriComponents;
import org.springframework.web.util.UriComponentsBuilder;
@ -59,7 +63,7 @@ import org.springframework.web.util.UriComponentsBuilder;
* Internal class used to transform a {@link WebRequest} into a
* {@link MockHttpServletRequest} using Spring MVC Test's {@link RequestBuilder}.
*
* <p>By default the first path segment of the URL is used as the contextPath.
* <p>By default the first path segment of the URL is used as the context path.
* To override this default see {@link #setContextPath(String)}.
*
* @author Rob Winch
@ -71,6 +75,11 @@ final class HtmlUnitRequestBuilder implements RequestBuilder, Mergeable {
private static final Pattern LOCALE_PATTERN = Pattern.compile("^\\s*(\\w{2})(?:-(\\w{2}))?(?:;q=(\\d+\\.\\d+))?$");
private static final Charset DEFAULT_CHARSET = Charset.forName("ISO-8859-1");
private static final Method getCharsetMethod = ClassUtils.getMethodIfAvailable(WebRequest.class, "getCharset");
private final Map<String, MockHttpSession> sessions;
private final WebClient webClient;
@ -98,23 +107,23 @@ final class HtmlUnitRequestBuilder implements RequestBuilder, Mergeable {
Assert.notNull(sessions, "Sessions Map must not be null");
Assert.notNull(webClient, "WebClient must not be null");
Assert.notNull(webRequest, "WebRequest must not be null");
this.sessions = sessions;
this.webClient = webClient;
this.webRequest = webRequest;
}
public MockHttpServletRequest buildRequest(ServletContext servletContext) {
String charset = getCharset();
Charset charset = getCharset();
String httpMethod = this.webRequest.getHttpMethod().name();
UriComponents uriComponents = uriComponents();
MockHttpServletRequest request = new HtmlUnitMockHttpServletRequest(
servletContext, httpMethod, uriComponents.getPath());
parent(request, this.parentBuilder);
request.setServerName(uriComponents.getHost()); // needs to be first for additional headers
request.setServerName(uriComponents.getHost()); // needs to be first for additional headers
authType(request);
request.setCharacterEncoding(charset);
request.setCharacterEncoding(charset.name());
content(request, charset);
contextPath(request, uriComponents);
contentType(request);
@ -132,6 +141,21 @@ final class HtmlUnitRequestBuilder implements RequestBuilder, Mergeable {
return postProcess(request);
}
private Charset getCharset() {
if (getCharsetMethod != null) {
Object value = ReflectionUtils.invokeMethod(getCharsetMethod, this.webRequest);
if (value instanceof Charset) {
// HtmlUnit 2.25: a Charset
return (Charset) value;
}
else if (value != null) {
// HtmlUnit up until 2.24: a String
return Charset.forName(value.toString());
}
}
return DEFAULT_CHARSET;
}
private MockHttpServletRequest postProcess(MockHttpServletRequest request) {
if (this.parentPostProcessor != null) {
request = this.parentPostProcessor.postProcessRequest(request);
@ -220,17 +244,12 @@ final class HtmlUnitRequestBuilder implements RequestBuilder, Mergeable {
}
}
private void content(MockHttpServletRequest request, String charset) {
private void content(MockHttpServletRequest request, Charset charset) {
String requestBody = this.webRequest.getRequestBody();
if (requestBody == null) {
return;
}
try {
request.setContent(requestBody.getBytes(charset));
}
catch (UnsupportedEncodingException ex) {
throw new IllegalStateException(ex);
}
request.setContent(requestBody.getBytes(charset));
}
private void contentType(MockHttpServletRequest request) {
@ -256,8 +275,8 @@ final class HtmlUnitRequestBuilder implements RequestBuilder, Mergeable {
}
else {
if (!uriComponents.getPath().startsWith(this.contextPath)) {
throw new IllegalArgumentException(uriComponents.getPath() + " should start with contextPath " +
this.contextPath);
throw new IllegalArgumentException("\"" + uriComponents.getPath() +
"\" should start with context path \"" + this.contextPath + "\"");
}
request.setContextPath(this.contextPath);
}
@ -273,7 +292,7 @@ final class HtmlUnitRequestBuilder implements RequestBuilder, Mergeable {
String cookieName = tokens.nextToken().trim();
if (!tokens.hasMoreTokens()) {
throw new IllegalArgumentException("Expected value for cookie name '" + cookieName +
"'. Full cookie was " + cookieHeaderValue);
"': full cookie header was [" + cookieHeaderValue + "]");
}
String cookieValue = tokens.nextToken().trim();
processCookie(request, cookies, new Cookie(cookieName, cookieValue));
@ -305,14 +324,6 @@ final class HtmlUnitRequestBuilder implements RequestBuilder, Mergeable {
}
}
private String getCharset() {
String charset = this.webRequest.getCharset();
if (charset == null) {
return "ISO-8859-1";
}
return charset;
}
private String header(String headerName) {
return this.webRequest.getAdditionalHeaders().get(headerName);
}
@ -394,7 +405,7 @@ final class HtmlUnitRequestBuilder implements RequestBuilder, Mergeable {
private Locale parseLocale(String locale) {
Matcher matcher = LOCALE_PATTERN.matcher(locale);
if (!matcher.matches()) {
throw new IllegalArgumentException("Invalid locale " + locale);
throw new IllegalArgumentException("Invalid locale value [" + locale + "]");
}
String language = matcher.group(1);
String country = matcher.group(2);