mirror of https://github.com/apache/kafka.git
KAFKA-16305: Avoid optimisation in handshakeUnwrap (#15434)
Performs additional unwrap during handshake after data from client is processed to support openssl, which needs the extra unwrap to complete handshake. Reviewers: Ismael Juma <ismael@juma.me.uk>, Rajini Sivaram <rajinisivaram@googlemail.com>
This commit is contained in:
parent
0e46be5bf0
commit
633d2f139c
|
@ -506,13 +506,14 @@ public class SslTransportLayer implements TransportLayer {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Perform handshake unwrap
|
* Perform handshake unwrap.
|
||||||
|
* Visible for testing.
|
||||||
* @param doRead boolean If true, read more from the socket channel
|
* @param doRead boolean If true, read more from the socket channel
|
||||||
* @param ignoreHandshakeStatus If true, continue to unwrap if data available regardless of handshake status
|
* @param ignoreHandshakeStatus If true, continue to unwrap if data available regardless of handshake status
|
||||||
* @return SSLEngineResult
|
* @return SSLEngineResult
|
||||||
* @throws IOException
|
* @throws IOException
|
||||||
*/
|
*/
|
||||||
private SSLEngineResult handshakeUnwrap(boolean doRead, boolean ignoreHandshakeStatus) throws IOException {
|
SSLEngineResult handshakeUnwrap(boolean doRead, boolean ignoreHandshakeStatus) throws IOException {
|
||||||
log.trace("SSLHandshake handshakeUnwrap {}", channelId);
|
log.trace("SSLHandshake handshakeUnwrap {}", channelId);
|
||||||
SSLEngineResult result;
|
SSLEngineResult result;
|
||||||
int read = 0;
|
int read = 0;
|
||||||
|
@ -534,7 +535,7 @@ public class SslTransportLayer implements TransportLayer {
|
||||||
handshakeStatus == HandshakeStatus.NEED_UNWRAP) ||
|
handshakeStatus == HandshakeStatus.NEED_UNWRAP) ||
|
||||||
(ignoreHandshakeStatus && netReadBuffer.position() != position);
|
(ignoreHandshakeStatus && netReadBuffer.position() != position);
|
||||||
log.trace("SSLHandshake handshakeUnwrap: handshakeStatus {} status {}", handshakeStatus, result.getStatus());
|
log.trace("SSLHandshake handshakeUnwrap: handshakeStatus {} status {}", handshakeStatus, result.getStatus());
|
||||||
} while (netReadBuffer.position() != 0 && cont);
|
} while (cont);
|
||||||
|
|
||||||
// Throw EOF exception for failed read after processing already received data
|
// Throw EOF exception for failed read after processing already received data
|
||||||
// so that handshake failures are reported correctly
|
// so that handshake failures are reported correctly
|
||||||
|
|
|
@ -53,6 +53,7 @@ import java.nio.ByteBuffer;
|
||||||
import java.nio.channels.Channels;
|
import java.nio.channels.Channels;
|
||||||
import java.nio.channels.SelectionKey;
|
import java.nio.channels.SelectionKey;
|
||||||
import java.nio.channels.SocketChannel;
|
import java.nio.channels.SocketChannel;
|
||||||
|
import java.nio.charset.StandardCharsets;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.Arrays;
|
import java.util.Arrays;
|
||||||
import java.util.Collection;
|
import java.util.Collection;
|
||||||
|
@ -67,14 +68,17 @@ import java.util.stream.Stream;
|
||||||
|
|
||||||
import javax.net.ssl.SSLContext;
|
import javax.net.ssl.SSLContext;
|
||||||
import javax.net.ssl.SSLEngine;
|
import javax.net.ssl.SSLEngine;
|
||||||
|
import javax.net.ssl.SSLEngineResult;
|
||||||
import javax.net.ssl.SSLException;
|
import javax.net.ssl.SSLException;
|
||||||
import javax.net.ssl.SSLParameters;
|
import javax.net.ssl.SSLParameters;
|
||||||
|
import javax.net.ssl.SSLSession;
|
||||||
|
|
||||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||||
import static org.junit.jupiter.api.Assertions.assertFalse;
|
import static org.junit.jupiter.api.Assertions.assertFalse;
|
||||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||||
import static org.junit.jupiter.api.Assertions.assertTrue;
|
import static org.junit.jupiter.api.Assertions.assertTrue;
|
||||||
import static org.junit.jupiter.api.Assumptions.assumeTrue;
|
import static org.junit.jupiter.api.Assumptions.assumeTrue;
|
||||||
|
import static org.mockito.ArgumentMatchers.any;
|
||||||
import static org.mockito.Mockito.doThrow;
|
import static org.mockito.Mockito.doThrow;
|
||||||
import static org.mockito.Mockito.mock;
|
import static org.mockito.Mockito.mock;
|
||||||
import static org.mockito.Mockito.times;
|
import static org.mockito.Mockito.times;
|
||||||
|
@ -1477,6 +1481,57 @@ public class SslTransportLayerTest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* SSLEngine implementations may transition from NEED_UNWRAP to NEED_UNWRAP
|
||||||
|
* even after reading all the data from the socket. This test ensures we
|
||||||
|
* continue unwrapping and not break early.
|
||||||
|
* Please refer <a href="https://issues.apache.org/jira/browse/KAFKA-16305">KAFKA-16305</a>
|
||||||
|
* for more information.
|
||||||
|
*/
|
||||||
|
@Test
|
||||||
|
public void testHandshakeUnwrapContinuesUnwrappingOnNeedUnwrapAfterAllBytesRead() throws IOException {
|
||||||
|
// Given
|
||||||
|
byte[] data = "ClientHello?".getBytes(StandardCharsets.UTF_8);
|
||||||
|
|
||||||
|
SSLEngine sslEngine = mock(SSLEngine.class);
|
||||||
|
SocketChannel socketChannel = mock(SocketChannel.class);
|
||||||
|
SelectionKey selectionKey = mock(SelectionKey.class);
|
||||||
|
when(selectionKey.channel()).thenReturn(socketChannel);
|
||||||
|
SSLSession sslSession = mock(SSLSession.class);
|
||||||
|
SslTransportLayer sslTransportLayer = new SslTransportLayer(
|
||||||
|
"test-channel",
|
||||||
|
selectionKey,
|
||||||
|
sslEngine,
|
||||||
|
mock(ChannelMetadataRegistry.class)
|
||||||
|
);
|
||||||
|
|
||||||
|
when(sslEngine.getSession()).thenReturn(sslSession);
|
||||||
|
when(sslSession.getPacketBufferSize()).thenReturn(data.length * 2);
|
||||||
|
sslTransportLayer.startHandshake(); // to initialize the buffers
|
||||||
|
|
||||||
|
ByteBuffer netReadBuffer = sslTransportLayer.netReadBuffer();
|
||||||
|
netReadBuffer.clear();
|
||||||
|
ByteBuffer appReadBuffer = sslTransportLayer.appReadBuffer();
|
||||||
|
when(socketChannel.read(any(ByteBuffer.class))).then(invocation -> {
|
||||||
|
((ByteBuffer) invocation.getArgument(0)).put(data);
|
||||||
|
return data.length;
|
||||||
|
});
|
||||||
|
|
||||||
|
when(sslEngine.unwrap(netReadBuffer, appReadBuffer))
|
||||||
|
.thenAnswer(invocation -> {
|
||||||
|
netReadBuffer.flip();
|
||||||
|
return new SSLEngineResult(SSLEngineResult.Status.OK, SSLEngineResult.HandshakeStatus.NEED_UNWRAP, data.length, 0);
|
||||||
|
}).thenReturn(new SSLEngineResult(SSLEngineResult.Status.OK, SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, 0));
|
||||||
|
|
||||||
|
// When
|
||||||
|
SSLEngineResult result = sslTransportLayer.handshakeUnwrap(true, false);
|
||||||
|
|
||||||
|
// Then
|
||||||
|
verify(sslEngine, times(2)).unwrap(netReadBuffer, appReadBuffer);
|
||||||
|
assertEquals(SSLEngineResult.Status.OK, result.getStatus());
|
||||||
|
assertEquals(SSLEngineResult.HandshakeStatus.NEED_WRAP, result.getHandshakeStatus());
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testSSLEngineCloseInboundInvokedOnClose() throws IOException {
|
public void testSSLEngineCloseInboundInvokedOnClose() throws IOException {
|
||||||
// Given
|
// Given
|
||||||
|
|
Loading…
Reference in New Issue