KAFKA-16305: Avoid optimisation in handshakeUnwrap (#15434)

Performs additional unwrap during handshake after data from client is processed to support openssl, which needs the extra unwrap to complete handshake.

Reviewers: Ismael Juma <ismael@juma.me.uk>, Rajini Sivaram <rajinisivaram@googlemail.com>
This commit is contained in:
Gaurav Narula 2024-02-28 09:37:58 +00:00 committed by Chia-Ping Tsai
parent 0e46be5bf0
commit 633d2f139c
2 changed files with 59 additions and 3 deletions

View File

@ -506,13 +506,14 @@ public class SslTransportLayer implements TransportLayer {
}
/**
* Perform handshake unwrap
* Perform handshake unwrap.
* Visible for testing.
* @param doRead boolean If true, read more from the socket channel
* @param ignoreHandshakeStatus If true, continue to unwrap if data available regardless of handshake status
* @return SSLEngineResult
* @throws IOException
*/
private SSLEngineResult handshakeUnwrap(boolean doRead, boolean ignoreHandshakeStatus) throws IOException {
SSLEngineResult handshakeUnwrap(boolean doRead, boolean ignoreHandshakeStatus) throws IOException {
log.trace("SSLHandshake handshakeUnwrap {}", channelId);
SSLEngineResult result;
int read = 0;
@ -534,7 +535,7 @@ public class SslTransportLayer implements TransportLayer {
handshakeStatus == HandshakeStatus.NEED_UNWRAP) ||
(ignoreHandshakeStatus && netReadBuffer.position() != position);
log.trace("SSLHandshake handshakeUnwrap: handshakeStatus {} status {}", handshakeStatus, result.getStatus());
} while (netReadBuffer.position() != 0 && cont);
} while (cont);
// Throw EOF exception for failed read after processing already received data
// so that handshake failures are reported correctly

View File

@ -53,6 +53,7 @@ import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
@ -67,14 +68,17 @@ import java.util.stream.Stream;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSession;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
@ -1477,6 +1481,57 @@ public class SslTransportLayerTest {
}
}
/**
* SSLEngine implementations may transition from NEED_UNWRAP to NEED_UNWRAP
* even after reading all the data from the socket. This test ensures we
* continue unwrapping and not break early.
* Please refer <a href="https://issues.apache.org/jira/browse/KAFKA-16305">KAFKA-16305</a>
* for more information.
*/
@Test
public void testHandshakeUnwrapContinuesUnwrappingOnNeedUnwrapAfterAllBytesRead() throws IOException {
// Given
byte[] data = "ClientHello?".getBytes(StandardCharsets.UTF_8);
SSLEngine sslEngine = mock(SSLEngine.class);
SocketChannel socketChannel = mock(SocketChannel.class);
SelectionKey selectionKey = mock(SelectionKey.class);
when(selectionKey.channel()).thenReturn(socketChannel);
SSLSession sslSession = mock(SSLSession.class);
SslTransportLayer sslTransportLayer = new SslTransportLayer(
"test-channel",
selectionKey,
sslEngine,
mock(ChannelMetadataRegistry.class)
);
when(sslEngine.getSession()).thenReturn(sslSession);
when(sslSession.getPacketBufferSize()).thenReturn(data.length * 2);
sslTransportLayer.startHandshake(); // to initialize the buffers
ByteBuffer netReadBuffer = sslTransportLayer.netReadBuffer();
netReadBuffer.clear();
ByteBuffer appReadBuffer = sslTransportLayer.appReadBuffer();
when(socketChannel.read(any(ByteBuffer.class))).then(invocation -> {
((ByteBuffer) invocation.getArgument(0)).put(data);
return data.length;
});
when(sslEngine.unwrap(netReadBuffer, appReadBuffer))
.thenAnswer(invocation -> {
netReadBuffer.flip();
return new SSLEngineResult(SSLEngineResult.Status.OK, SSLEngineResult.HandshakeStatus.NEED_UNWRAP, data.length, 0);
}).thenReturn(new SSLEngineResult(SSLEngineResult.Status.OK, SSLEngineResult.HandshakeStatus.NEED_WRAP, 0, 0));
// When
SSLEngineResult result = sslTransportLayer.handshakeUnwrap(true, false);
// Then
verify(sslEngine, times(2)).unwrap(netReadBuffer, appReadBuffer);
assertEquals(SSLEngineResult.Status.OK, result.getStatus());
assertEquals(SSLEngineResult.HandshakeStatus.NEED_WRAP, result.getHandshakeStatus());
}
@Test
public void testSSLEngineCloseInboundInvokedOnClose() throws IOException {
// Given