Ensure WebSocketHttpRequestHandler writes headers
Closes gh-23179
This commit is contained in:
parent
6e79dcdc8e
commit
5af9a8edae
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
@ -166,7 +166,6 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycl
|
|||
}
|
||||
this.handshakeHandler.doHandshake(request, response, this.wsHandler, attributes);
|
||||
chain.applyAfterHandshake(request, response, null);
|
||||
response.close();
|
||||
}
|
||||
catch (HandshakeFailureException ex) {
|
||||
failure = ex;
|
||||
|
|
@ -177,8 +176,10 @@ public class WebSocketHttpRequestHandler implements HttpRequestHandler, Lifecycl
|
|||
finally {
|
||||
if (failure != null) {
|
||||
chain.applyAfterHandshake(request, response, failure);
|
||||
response.close();
|
||||
throw failure;
|
||||
}
|
||||
response.close();
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
/*
|
||||
* Copyright 2002-2018 the original author or authors.
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
|
@ -26,6 +26,7 @@ import org.junit.Test;
|
|||
import org.mockito.Mock;
|
||||
import org.mockito.MockitoAnnotations;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.web.socket.AbstractHttpRequestTests;
|
||||
import org.springframework.web.socket.SubProtocolCapable;
|
||||
import org.springframework.web.socket.WebSocketExtension;
|
||||
|
|
@ -62,14 +63,9 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
|
|||
public void supportedSubProtocols() {
|
||||
this.handshakeHandler.setSupportedProtocols("stomp", "mqtt");
|
||||
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
|
||||
this.servletRequest.setMethod("GET");
|
||||
|
||||
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
|
||||
headers.setUpgrade("WebSocket");
|
||||
headers.setConnection("Upgrade");
|
||||
headers.setSecWebSocketVersion("13");
|
||||
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
|
||||
headers.setSecWebSocketProtocol("STOMP");
|
||||
this.servletRequest.setMethod("GET");
|
||||
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("STOMP");
|
||||
|
||||
WebSocketHandler handler = new TextWebSocketHandler();
|
||||
Map<String, Object> attributes = Collections.emptyMap();
|
||||
|
|
@ -88,16 +84,10 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
|
|||
given(this.upgradeStrategy.getSupportedExtensions(this.request)).willReturn(Collections.singletonList(extension1));
|
||||
|
||||
this.servletRequest.setMethod("GET");
|
||||
|
||||
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
|
||||
headers.setUpgrade("WebSocket");
|
||||
headers.setConnection("Upgrade");
|
||||
headers.setSecWebSocketVersion("13");
|
||||
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
|
||||
headers.setSecWebSocketExtensions(Arrays.asList(extension1, extension2));
|
||||
initHeaders(this.request.getHeaders()).setSecWebSocketExtensions(Arrays.asList(extension1, extension2));
|
||||
|
||||
WebSocketHandler handler = new TextWebSocketHandler();
|
||||
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
|
||||
Map<String, Object> attributes = Collections.emptyMap();
|
||||
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
|
||||
|
||||
verify(this.upgradeStrategy).upgrade(this.request, this.response, null,
|
||||
|
|
@ -109,16 +99,10 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
|
|||
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
|
||||
|
||||
this.servletRequest.setMethod("GET");
|
||||
|
||||
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
|
||||
headers.setUpgrade("WebSocket");
|
||||
headers.setConnection("Upgrade");
|
||||
headers.setSecWebSocketVersion("13");
|
||||
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
|
||||
headers.setSecWebSocketProtocol("v11.stomp");
|
||||
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("v11.stomp");
|
||||
|
||||
WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp");
|
||||
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
|
||||
Map<String, Object> attributes = Collections.emptyMap();
|
||||
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
|
||||
|
||||
verify(this.upgradeStrategy).upgrade(this.request, this.response, "v11.stomp",
|
||||
|
|
@ -130,22 +114,25 @@ public class DefaultHandshakeHandlerTests extends AbstractHttpRequestTests {
|
|||
given(this.upgradeStrategy.getSupportedVersions()).willReturn(new String[] {"13"});
|
||||
|
||||
this.servletRequest.setMethod("GET");
|
||||
|
||||
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(this.request.getHeaders());
|
||||
headers.setUpgrade("WebSocket");
|
||||
headers.setConnection("Upgrade");
|
||||
headers.setSecWebSocketVersion("13");
|
||||
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
|
||||
headers.setSecWebSocketProtocol("v10.stomp");
|
||||
initHeaders(this.request.getHeaders()).setSecWebSocketProtocol("v10.stomp");
|
||||
|
||||
WebSocketHandler handler = new SubProtocolCapableHandler("v12.stomp", "v11.stomp");
|
||||
Map<String, Object> attributes = Collections.<String, Object>emptyMap();
|
||||
Map<String, Object> attributes = Collections.emptyMap();
|
||||
this.handshakeHandler.doHandshake(this.request, this.response, handler, attributes);
|
||||
|
||||
verify(this.upgradeStrategy).upgrade(this.request, this.response, null,
|
||||
Collections.emptyList(), null, handler, attributes);
|
||||
}
|
||||
|
||||
private WebSocketHttpHeaders initHeaders(HttpHeaders httpHeaders) {
|
||||
WebSocketHttpHeaders headers = new WebSocketHttpHeaders(httpHeaders);
|
||||
headers.setUpgrade("WebSocket");
|
||||
headers.setConnection("Upgrade");
|
||||
headers.setSecWebSocketVersion("13");
|
||||
headers.setSecWebSocketKey("82/ZS2YHjEnUN97HLL8tbw==");
|
||||
return headers;
|
||||
}
|
||||
|
||||
|
||||
private static class SubProtocolCapableHandler extends TextWebSocketHandler implements SubProtocolCapable {
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,141 @@
|
|||
/*
|
||||
* Copyright 2002-2019 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* https://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.springframework.web.socket.server.support;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Collections;
|
||||
import java.util.Map;
|
||||
import javax.servlet.ServletException;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
|
||||
import org.springframework.http.server.ServerHttpRequest;
|
||||
import org.springframework.http.server.ServerHttpResponse;
|
||||
import org.springframework.mock.web.test.MockHttpServletRequest;
|
||||
import org.springframework.mock.web.test.MockHttpServletResponse;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
import org.springframework.web.socket.server.HandshakeFailureException;
|
||||
import org.springframework.web.socket.server.HandshakeHandler;
|
||||
import org.springframework.web.socket.server.HandshakeInterceptor;
|
||||
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertSame;
|
||||
import static org.junit.Assert.fail;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link WebSocketHttpRequestHandler}.
|
||||
* @author Rossen Stoyanchev
|
||||
* @since 5.1.9
|
||||
*/
|
||||
public class WebSocketHttpRequestHandlerTests {
|
||||
|
||||
private HandshakeHandler handshakeHandler;
|
||||
|
||||
private WebSocketHttpRequestHandler requestHandler;
|
||||
|
||||
private MockHttpServletResponse response;
|
||||
|
||||
|
||||
@Before
|
||||
public void setUp() {
|
||||
this.handshakeHandler = mock(HandshakeHandler.class);
|
||||
this.requestHandler = new WebSocketHttpRequestHandler(mock(WebSocketHandler.class), this.handshakeHandler);
|
||||
this.response = new MockHttpServletResponse();
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
public void success() throws ServletException, IOException {
|
||||
TestInterceptor interceptor = new TestInterceptor(true);
|
||||
this.requestHandler.setHandshakeInterceptors(Collections.singletonList(interceptor));
|
||||
this.requestHandler.handleRequest(new MockHttpServletRequest(), this.response);
|
||||
|
||||
verify(this.handshakeHandler).doHandshake(any(), any(), any(), any());
|
||||
assertEquals("headerValue", this.response.getHeader("headerName"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void failure() throws ServletException, IOException {
|
||||
TestInterceptor interceptor = new TestInterceptor(true);
|
||||
this.requestHandler.setHandshakeInterceptors(Collections.singletonList(interceptor));
|
||||
|
||||
when(this.handshakeHandler.doHandshake(any(), any(), any(), any()))
|
||||
.thenThrow(new IllegalStateException("bad state"));
|
||||
|
||||
try {
|
||||
this.requestHandler.handleRequest(new MockHttpServletRequest(), this.response);
|
||||
fail();
|
||||
}
|
||||
catch (HandshakeFailureException ex) {
|
||||
assertSame(ex, interceptor.getException());
|
||||
assertEquals("headerValue", this.response.getHeader("headerName"));
|
||||
assertEquals("exceptionHeaderValue", this.response.getHeader("exceptionHeaderName"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test // gh-23179
|
||||
public void handshakeNotAllowed() throws ServletException, IOException {
|
||||
TestInterceptor interceptor = new TestInterceptor(false);
|
||||
this.requestHandler.setHandshakeInterceptors(Collections.singletonList(interceptor));
|
||||
|
||||
this.requestHandler.handleRequest(new MockHttpServletRequest(), this.response);
|
||||
|
||||
verifyNoMoreInteractions(this.handshakeHandler);
|
||||
assertEquals("headerValue", this.response.getHeader("headerName"));
|
||||
}
|
||||
|
||||
|
||||
private static class TestInterceptor implements HandshakeInterceptor {
|
||||
|
||||
private final boolean allowHandshake;
|
||||
|
||||
private Exception exception;
|
||||
|
||||
|
||||
private TestInterceptor(boolean allowHandshake) {
|
||||
this.allowHandshake = allowHandshake;
|
||||
}
|
||||
|
||||
|
||||
public Exception getException() {
|
||||
return this.exception;
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response,
|
||||
WebSocketHandler wsHandler, Map<String, Object> attributes) {
|
||||
|
||||
response.getHeaders().add("headerName", "headerValue");
|
||||
return this.allowHandshake;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response,
|
||||
WebSocketHandler wsHandler, Exception exception) {
|
||||
|
||||
response.getHeaders().add("exceptionHeaderName", "exceptionHeaderValue");
|
||||
this.exception = exception;
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
Reference in New Issue