diff --git a/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java b/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java index bcb011e830b..ac80d046fc2 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Authenticator.java @@ -21,8 +21,7 @@ import org.apache.kafka.common.security.auth.KafkaPrincipal; import java.io.Closeable; import java.io.IOException; -import java.util.Collections; -import java.util.List; +import java.util.Optional; /** * Authentication for Channel @@ -131,19 +130,21 @@ public interface Authenticator extends Closeable { } /** - * Return the (always non-null but possibly empty) client-side - * {@link NetworkReceive} responses that arrived during re-authentication that - * are unrelated to re-authentication, if any. These correspond to requests sent + * Return the next (always non-null but possibly empty) client-side + * {@link NetworkReceive} response that arrived during re-authentication that + * is unrelated to re-authentication, if any. These correspond to requests sent * prior to the beginning of re-authentication; the requests were made when the * channel was successfully authenticated, and the responses arrived during the - * re-authentication process. + * re-authentication process. The response returned is removed from the authenticator's + * queue. Responses of requests sent after completion of re-authentication are + * processed only when the authenticator response queue is empty. * * @return the (always non-null but possibly empty) client-side - * {@link NetworkReceive} responses that arrived during - * re-authentication that are unrelated to re-authentication, if any + * {@link NetworkReceive} response that arrived during + * re-authentication that is unrelated to re-authentication, if any */ - default List getAndClearResponsesReceivedDuringReauthentication() { - return Collections.emptyList(); + default Optional pollResponseReceivedDuringReauthentication() { + return Optional.empty(); } /** diff --git a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java index 8d9465c4350..4e4edd47adb 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java +++ b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java @@ -28,8 +28,8 @@ import java.net.Socket; import java.net.SocketAddress; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; -import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.function.Supplier; /** @@ -631,18 +631,18 @@ public class KafkaChannel implements AutoCloseable { /** * Return the (always non-null but possibly empty) client-side - * {@link NetworkReceive} responses that arrived during re-authentication that - * are unrelated to re-authentication, if any. These correspond to requests sent - * prior to the beginning of re-authentication; the requests were made when the - * channel was successfully authenticated, and the responses arrived during the + * {@link NetworkReceive} response that arrived during re-authentication but + * is unrelated to re-authentication. This corresponds to a request sent + * prior to the beginning of re-authentication; the request was made when the + * channel was successfully authenticated, and the response arrived during the * re-authentication process. * - * @return the (always non-null but possibly empty) client-side - * {@link NetworkReceive} responses that arrived during - * re-authentication that are unrelated to re-authentication, if any + * @return client-side {@link NetworkReceive} response that arrived during + * re-authentication that is unrelated to re-authentication. This may + * be empty. */ - public List getAndClearResponsesReceivedDuringReauthentication() { - return authenticator.getAndClearResponsesReceivedDuringReauthentication(); + public Optional pollResponseReceivedDuringReauthentication() { + return authenticator.pollResponseReceivedDuringReauthentication(); } /** diff --git a/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java b/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java index 705e10f0cff..98a0c74d80d 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java +++ b/clients/src/main/java/org/apache/kafka/common/network/PlaintextChannelBuilder.java @@ -26,6 +26,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.Closeable; +import java.io.IOException; import java.net.InetAddress; import java.nio.channels.SelectionKey; import java.util.Map; @@ -52,7 +53,7 @@ public class PlaintextChannelBuilder implements ChannelBuilder { public KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize, MemoryPool memoryPool, ChannelMetadataRegistry metadataRegistry) throws KafkaException { try { - PlaintextTransportLayer transportLayer = new PlaintextTransportLayer(key); + PlaintextTransportLayer transportLayer = buildTransportLayer(key); Supplier authenticatorCreator = () -> new PlaintextAuthenticator(configs, transportLayer, listenerName); return new KafkaChannel(id, transportLayer, authenticatorCreator, maxReceiveSize, memoryPool != null ? memoryPool : MemoryPool.NONE, metadataRegistry); @@ -62,6 +63,10 @@ public class PlaintextChannelBuilder implements ChannelBuilder { } } + protected PlaintextTransportLayer buildTransportLayer(SelectionKey key) throws IOException { + return new PlaintextTransportLayer(key); + } + @Override public void close() {} diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selectable.java b/clients/src/main/java/org/apache/kafka/common/network/Selectable.java index 8f81dbeebe9..d799c91c570 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Selectable.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Selectable.java @@ -19,6 +19,7 @@ package org.apache.kafka.common.network; import java.io.IOException; import java.net.InetSocketAddress; +import java.util.Collection; import java.util.List; import java.util.Map; @@ -76,9 +77,9 @@ public interface Selectable { List completedSends(); /** - * The list of receives that completed on the last {@link #poll(long) poll()} call. + * The collection of receives that completed on the last {@link #poll(long) poll()} call. */ - List completedReceives(); + Collection completedReceives(); /** * The connections that finished disconnecting on the last {@link #poll(long) poll()} diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java index bd4f31ada50..cb91cad9257 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java @@ -41,17 +41,16 @@ import java.nio.channels.CancelledKeyException; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; import java.nio.channels.UnresolvedAddressException; -import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.Deque; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; @@ -91,7 +90,7 @@ public class Selector implements Selectable, AutoCloseable { public static final int NO_FAILED_AUTHENTICATION_DELAY = 0; private enum CloseMode { - GRACEFUL(true), // process outstanding staged receives, notify disconnect + GRACEFUL(true), // process outstanding buffered receives, notify disconnect NOTIFY_ONLY(true), // discard any outstanding receives, notify disconnect DISCARD_NO_NOTIFY(false); // discard any outstanding receives, no disconnect notification @@ -108,8 +107,7 @@ public class Selector implements Selectable, AutoCloseable { private final Set explicitlyMutedChannels; private boolean outOfMemory; private final List completedSends; - private final List completedReceives; - private final Map> stagedReceives; + private final LinkedHashMap completedReceives; private final Set immediatelyConnectedKeys; private final Map closingChannels; private Set keysWithBufferedRead; @@ -168,8 +166,7 @@ public class Selector implements Selectable, AutoCloseable { this.explicitlyMutedChannels = new HashSet<>(); this.outOfMemory = false; this.completedSends = new ArrayList<>(); - this.completedReceives = new ArrayList<>(); - this.stagedReceives = new HashMap<>(); + this.completedReceives = new LinkedHashMap<>(); this.immediatelyConnectedKeys = new HashSet<>(); this.closingChannels = new HashMap<>(); this.keysWithBufferedRead = new HashSet<>(); @@ -428,11 +425,10 @@ public class Selector implements Selectable, AutoCloseable { * This requires additional buffers to be maintained as we are reading from network, since the data on the wire is encrypted * we won't be able to read exact no.of bytes as kafka protocol requires. We read as many bytes as we can, up to SSLEngine's * application buffer size. This means we might be reading additional bytes than the requested size. - * If there is no further data to read from socketChannel selector won't invoke that channel and we've have additional bytes - * in the buffer. To overcome this issue we added "stagedReceives" map which contains per-channel deque. When we are - * reading a channel we read as many responses as we can and store them into "stagedReceives" and pop one response during - * the poll to add the completedReceives. If there are any active channels in the "stagedReceives" we set "timeout" to 0 - * and pop response and add to the completedReceives. + * If there is no further data to read from socketChannel selector won't invoke that channel and we have additional bytes + * in the buffer. To overcome this issue we added "keysWithBufferedRead" map which tracks channels which have data in the SSL + * buffers. If there are channels with buffered data that can by processed, we set "timeout" to 0 and process the data even + * if there is no more data to read from the socket. * * Atmost one entry is added to "completedReceives" for a channel in each poll. This is necessary to guarantee that * requests from a channel are processed on the broker in the order they are sent. Since outstanding requests added @@ -454,7 +450,7 @@ public class Selector implements Selectable, AutoCloseable { boolean dataInBuffers = !keysWithBufferedRead.isEmpty(); - if (hasStagedReceives() || !immediatelyConnectedKeys.isEmpty() || (madeReadProgressLastCall && dataInBuffers)) + if (!immediatelyConnectedKeys.isEmpty() || (madeReadProgressLastCall && dataInBuffers)) timeout = 0; if (!memoryPool.isOutOfMemory() && outOfMemory) { @@ -505,10 +501,6 @@ public class Selector implements Selectable, AutoCloseable { // we use the time at the end of select to ensure that we don't close any connections that // have just been processed in pollSelectionKeys maybeCloseOldestConnection(endSelect); - - // Add to completedReceives after closing expired connections to avoid removing - // channels with completed receives until all staged receives are completed. - addToCompletedReceives(); } /** @@ -572,12 +564,21 @@ public class Selector implements Selectable, AutoCloseable { log.debug("Successfully {}authenticated with {}", isReauthentication ? "re-" : "", channel.socketDescription()); } - List responsesReceivedDuringReauthentication = channel - .getAndClearResponsesReceivedDuringReauthentication(); - responsesReceivedDuringReauthentication.forEach(receive -> addToStagedReceives(channel, receive)); } + if (channel.ready() && channel.state() == ChannelState.NOT_CONNECTED) + channel.state(ChannelState.READY); + Optional responseReceivedDuringReauthentication = channel.pollResponseReceivedDuringReauthentication(); + responseReceivedDuringReauthentication.ifPresent(receive -> { + long currentTimeMs = time.milliseconds(); + addToCompletedReceives(channel, receive, currentTimeMs); + }); - attemptRead(key, channel); + //if channel is ready and has bytes to read from socket or buffer, and has no + //previous completed receive then read from it + if (channel.ready() && (key.isReadable() || channel.hasBytesBuffered()) && !hasCompletedReceive(channel) + && !explicitlyMutedChannels.contains(channel)) { + attemptRead(channel); + } if (channel.hasBytesBuffered()) { //this channel has bytes enqueued in intermediary buffers that we could not read @@ -671,37 +672,43 @@ public class Selector implements Selectable, AutoCloseable { } } - private void attemptRead(SelectionKey key, KafkaChannel channel) throws IOException { - //if channel is ready and has bytes to read from socket or buffer, and has no - //previous receive(s) already staged or otherwise in progress then read from it - if (channel.ready() && (key.isReadable() || channel.hasBytesBuffered()) && !hasStagedReceive(channel) - && !explicitlyMutedChannels.contains(channel)) { + private void attemptRead(KafkaChannel channel) throws IOException { + String nodeId = channel.id(); - String nodeId = channel.id(); + long bytesReceived = channel.read(); + if (bytesReceived != 0) { + long currentTimeMs = time.milliseconds(); + sensors.recordBytesReceived(nodeId, bytesReceived, currentTimeMs); + madeReadProgressLastPoll = true; - while (true) { - long bytesReceived = channel.read(); - if (bytesReceived == 0) - break; - - long currentTimeMs = time.milliseconds(); - sensors.recordBytesReceived(nodeId, bytesReceived, currentTimeMs); - madeReadProgressLastPoll = true; - - NetworkReceive receive = channel.maybeCompleteReceive(); - if (receive == null) - break; - - sensors.recordCompletedReceive(nodeId, receive.size(), currentTimeMs); - addToStagedReceives(channel, receive); - } - - if (channel.isMuted()) { - outOfMemory = true; //channel has muted itself due to memory pressure. - } else { - madeReadProgressLastPoll = true; + NetworkReceive receive = channel.maybeCompleteReceive(); + if (receive != null) { + addToCompletedReceives(channel, receive, currentTimeMs); } } + if (channel.isMuted()) { + outOfMemory = true; //channel has muted itself due to memory pressure. + } else { + madeReadProgressLastPoll = true; + } + } + + private boolean maybeReadFromClosingChannel(KafkaChannel channel) { + boolean hasPending; + if (channel.state().state() != ChannelState.State.READY) + hasPending = false; + else if (explicitlyMutedChannels.contains(channel) || hasCompletedReceive(channel)) + hasPending = true; + else { + try { + attemptRead(channel); + hasPending = hasCompletedReceive(channel); + } catch (Exception e) { + log.trace("Read from closing channel failed, ignoring exception", e); + hasPending = false; + } + } + return hasPending; } // Record time spent in pollSelectionKeys for channel (moved into a method to keep checkstyle happy) @@ -716,8 +723,8 @@ public class Selector implements Selectable, AutoCloseable { } @Override - public List completedReceives() { - return this.completedReceives; + public Collection completedReceives() { + return this.completedReceives.values(); } @Override @@ -805,12 +812,14 @@ public class Selector implements Selectable, AutoCloseable { this.connected.clear(); this.disconnected.clear(); - // Remove closed channels after all their staged receives have been processed or if a send was requested + // Remove closed channels after all their buffered receives have been processed or if a send was requested for (Iterator> it = closingChannels.entrySet().iterator(); it.hasNext(); ) { KafkaChannel channel = it.next().getValue(); - Deque deque = this.stagedReceives.get(channel); boolean sendFailed = failedSends.remove(channel.id()); - if (deque == null || deque.isEmpty() || sendFailed) { + boolean hasPending = false; + if (!sendFailed) + hasPending = maybeReadFromClosingChannel(channel); + if (!hasPending || sendFailed) { doClose(channel, true); it.remove(); } @@ -876,7 +885,7 @@ public class Selector implements Selectable, AutoCloseable { /** * Begin closing this connection. - * If 'closeMode' is `CloseMode.GRACEFUL`, the channel is disconnected here, but staged receives + * If 'closeMode' is `CloseMode.GRACEFUL`, the channel is disconnected here, but outstanding receives * are processed. The channel is closed when there are no outstanding receives or if a send is * requested. For other values of `closeMode`, outstanding receives are discarded and the channel * is closed immediately. @@ -897,9 +906,7 @@ public class Selector implements Selectable, AutoCloseable { // handle close(). When the remote end closes its connection, the channel is retained until // a send fails or all outstanding receives are processed. Mute state of disconnected channels // are tracked to ensure that requests are processed one-by-one by the broker to preserve ordering. - Deque deque = this.stagedReceives.get(channel); - if (closeMode == CloseMode.GRACEFUL && deque != null && !deque.isEmpty()) { - // stagedReceives will be moved to completedReceives later along with receives from other channels + if (closeMode == CloseMode.GRACEFUL && maybeReadFromClosingChannel(channel)) { closingChannels.put(channel.id(), channel); log.debug("Tracking closing connection {} to process outstanding requests", channel.id()); } else { @@ -928,7 +935,7 @@ public class Selector implements Selectable, AutoCloseable { } this.sensors.connectionClosed.record(); - this.stagedReceives.remove(channel); + this.completedReceives.remove(channel); this.explicitlyMutedChannels.remove(channel); if (notifyDisconnect) this.disconnected.put(channel.id(), channel.state()); @@ -1005,57 +1012,21 @@ public class Selector implements Selectable, AutoCloseable { } /** - * Check if given channel has a staged receive + * Check if given channel has a completed receive */ - private boolean hasStagedReceive(KafkaChannel channel) { - return stagedReceives.containsKey(channel); + private boolean hasCompletedReceive(KafkaChannel channel) { + return completedReceives.containsKey(channel); } /** - * check if stagedReceives have unmuted channel + * adds a receive to completed receives */ - private boolean hasStagedReceives() { - for (KafkaChannel channel : this.stagedReceives.keySet()) { - if (!channel.isMuted()) - return true; - } - return false; - } + private void addToCompletedReceives(KafkaChannel channel, NetworkReceive networkReceive, long currentTimeMs) { + if (hasCompletedReceive(channel)) + throw new IllegalStateException("Attempting to add second completed receive to channel " + channel.id()); - - /** - * adds a receive to staged receives - */ - private void addToStagedReceives(KafkaChannel channel, NetworkReceive receive) { - if (!stagedReceives.containsKey(channel)) - stagedReceives.put(channel, new ArrayDeque<>()); - - Deque deque = stagedReceives.get(channel); - deque.add(receive); - } - - /** - * checks if there are any staged receives and adds to completedReceives - */ - private void addToCompletedReceives() { - if (!this.stagedReceives.isEmpty()) { - Iterator>> iter = this.stagedReceives.entrySet().iterator(); - while (iter.hasNext()) { - Map.Entry> entry = iter.next(); - KafkaChannel channel = entry.getKey(); - if (!explicitlyMutedChannels.contains(channel)) { - Deque deque = entry.getValue(); - addToCompletedReceives(channel, deque); - if (deque.isEmpty()) - iter.remove(); - } - } - } - } - - private void addToCompletedReceives(KafkaChannel channel, Deque stagedDeque) { - NetworkReceive networkReceive = stagedDeque.poll(); - this.completedReceives.add(networkReceive); + this.completedReceives.put(channel, networkReceive); + sensors.recordCompletedReceive(channel.id(), networkReceive.size(), currentTimeMs); } // only for testing @@ -1063,11 +1034,6 @@ public class Selector implements Selectable, AutoCloseable { return new HashSet<>(nioSelector.keys()); } - // only for testing - public int numStagedReceives(KafkaChannel channel) { - Deque deque = stagedReceives.get(channel); - return deque == null ? 0 : deque.size(); - } class SelectorChannelMetadataRegistry implements ChannelMetadataRegistry { private CipherInformation cipherInformation; diff --git a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java index 6b47b5be954..6784a0149c4 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java +++ b/clients/src/main/java/org/apache/kafka/common/security/authenticator/SaslClientAuthenticator.java @@ -64,11 +64,11 @@ import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Random; import java.util.Set; @@ -311,8 +311,8 @@ public class SaslClientAuthenticator implements Authenticator { } @Override - public List getAndClearResponsesReceivedDuringReauthentication() { - return reauthInfo.getAndClearResponsesReceivedDuringReauthentication(); + public Optional pollResponseReceivedDuringReauthentication() { + return reauthInfo.pollResponseReceivedDuringReauthentication(); } @Override @@ -602,23 +602,21 @@ public class SaslClientAuthenticator implements Authenticator { } /** - * Return the (always non-null but possibly empty) NetworkReceive responses that - * arrived during re-authentication that are unrelated to re-authentication, if - * any. These correspond to requests sent prior to the beginning of - * re-authentication; the requests were made when the channel was successfully - * authenticated, and the responses arrived during the re-authentication + * Return the (always non-null but possibly empty) NetworkReceive response that + * arrived during re-authentication that is unrelated to re-authentication, if + * any. This corresponds to a request sent prior to the beginning of + * re-authentication; the request was made when the channel was successfully + * authenticated, and the response arrived during the re-authentication * process. * - * @return the (always non-null but possibly empty) NetworkReceive responses - * that arrived during re-authentication that are unrelated to + * @return the (always non-null but possibly empty) NetworkReceive response + * that arrived during re-authentication that is unrelated to * re-authentication, if any */ - public List getAndClearResponsesReceivedDuringReauthentication() { + public Optional pollResponseReceivedDuringReauthentication() { if (pendingAuthenticatedReceives.isEmpty()) - return Collections.emptyList(); - List retval = pendingAuthenticatedReceives; - pendingAuthenticatedReceives = new ArrayList<>(); - return retval; + return Optional.empty(); + return Optional.of(pendingAuthenticatedReceives.remove(0)); } public void setAuthenticationEndAndSessionReauthenticationTimes(long nowNanos) { diff --git a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java index 9a72362ca4c..36ac1ed06b5 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java +++ b/clients/src/test/java/org/apache/kafka/common/network/NioEchoServer.java @@ -40,6 +40,7 @@ import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.nio.channels.WritableByteChannel; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.EnumSet; import java.util.Iterator; @@ -222,7 +223,7 @@ public class NioEchoServer extends Thread { selector.close(channel.id()); } - List completedReceives = selector.completedReceives(); + Collection completedReceives = selector.completedReceives(); for (NetworkReceive rcv : completedReceives) { KafkaChannel channel = channel(rcv.source()); if (!maybeBeginServerReauthentication(channel, rcv, time)) { diff --git a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java index 70eb588299d..fd9c7c3c7f8 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java @@ -94,7 +94,7 @@ public class SelectorTest { this.server.start(); this.time = new MockTime(); this.channelBuilder = new PlaintextChannelBuilder(ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)); - this.channelBuilder.configure(configs); + this.channelBuilder.configure(clientConfigs()); this.metrics = new Metrics(); this.selector = new Selector(5000, this.metrics, time, METRIC_GROUP, channelBuilder, new LogContext()); } @@ -114,6 +114,10 @@ public class SelectorTest { return SecurityProtocol.PLAINTEXT; } + protected Map clientConfigs() { + return new HashMap<>(); + } + /** * Validate that when the server disconnects, a client send ends up with that node in the disconnected list. */ @@ -356,14 +360,14 @@ public class SelectorTest { while (selector.completedReceives().isEmpty()) selector.poll(5); assertEquals("We should have only one response", 1, selector.completedReceives().size()); - assertEquals("The response should not be from the muted node", "0", selector.completedReceives().get(0).source()); + assertEquals("The response should not be from the muted node", "0", selector.completedReceives().iterator().next().source()); selector.unmute("1"); do { selector.poll(5); } while (selector.completedReceives().isEmpty()); assertEquals("We should have only one response", 1, selector.completedReceives().size()); - assertEquals("The response should be from the previously muted node", "1", selector.completedReceives().get(0).source()); + assertEquals("The response should be from the previously muted node", "1", selector.completedReceives().iterator().next().source()); } @Test @@ -391,26 +395,6 @@ public class SelectorTest { selector.close(); } - @Test - public void testCloseConnectionInClosingState() throws Exception { - KafkaChannel channel = createConnectionWithStagedReceives(5); - String id = channel.id(); - selector.mute(id); // Mute to allow channel to be expired even if more data is available for read - time.sleep(6000); // The max idle time is 5000ms - selector.poll(0); - assertNull("Channel not expired", selector.channel(id)); - assertEquals(channel, selector.closingChannel(id)); - assertEquals(ChannelState.EXPIRED, channel.state()); - selector.close(id); - assertNull("Channel not removed from channels", selector.channel(id)); - assertNull("Channel not removed from closingChannels", selector.closingChannel(id)); - assertTrue("Unexpected disconnect notification", selector.disconnected().isEmpty()); - assertEquals(ChannelState.EXPIRED, channel.state()); - assertNull(channel.selectionKey().attachment()); - selector.poll(0); - assertTrue("Unexpected disconnect notification", selector.disconnected().isEmpty()); - } - @Test public void testCloseOldestConnection() throws Exception { String id = "0"; @@ -522,71 +506,126 @@ public class SelectorTest { } } + /* + * Verifies that a muted connection is expired on idle timeout even if there are pending + * receives on the socket. + */ @Test - public void testCloseOldestConnectionWithOneStagedReceive() throws Exception { - verifyCloseOldestConnectionWithStagedReceives(1); + public void testExpireConnectionWithPendingReceives() throws Exception { + KafkaChannel channel = createConnectionWithPendingReceives(5); + verifyChannelExpiry(channel); } + /** + * Verifies that a muted connection closed by peer is expired on idle timeout even if there are pending + * receives on the socket. + */ @Test - public void testCloseOldestConnectionWithMultipleStagedReceives() throws Exception { - verifyCloseOldestConnectionWithStagedReceives(5); + public void testExpireClosedConnectionWithPendingReceives() throws Exception { + KafkaChannel channel = createConnectionWithPendingReceives(5); + server.closeConnections(); + verifyChannelExpiry(channel); } - private KafkaChannel createConnectionWithStagedReceives(int maxStagedReceives) throws Exception { - String id = "0"; - blockingConnect(id); - KafkaChannel channel = selector.channel(id); - int retries = 100; - - do { - selector.mute(id); - for (int i = 0; i <= maxStagedReceives; i++) { - selector.send(createSend(id, String.valueOf(i))); - do { - selector.poll(1000); - } while (selector.completedSends().isEmpty()); - } - - selector.unmute(id); - do { - selector.poll(1000); - } while (selector.completedReceives().isEmpty()); - } while ((selector.numStagedReceives(channel) == 0 || channel.hasBytesBuffered()) && --retries > 0); - assertTrue("No staged receives after 100 attempts", selector.numStagedReceives(channel) > 0); - // We want to return without any bytes buffered to ensure that channel will be closed after idle time - assertFalse("Channel has bytes buffered", channel.hasBytesBuffered()); - - return channel; - } - - private void verifyCloseOldestConnectionWithStagedReceives(int maxStagedReceives) throws Exception { - KafkaChannel channel = createConnectionWithStagedReceives(maxStagedReceives); + private void verifyChannelExpiry(KafkaChannel channel) throws Exception { + String id = channel.id(); + selector.mute(id); // Mute to allow channel to be expired even if more data is available for read + time.sleep(6000); // The max idle time is 5000ms + selector.poll(0); + assertNull("Channel not expired", selector.channel(id)); + assertNull("Channel not removed from closingChannels", selector.closingChannel(id)); + assertEquals(ChannelState.EXPIRED, channel.state()); + assertNull(channel.selectionKey().attachment()); + assertTrue("Disconnect not notified", selector.disconnected().containsKey(id)); + assertEquals(ChannelState.EXPIRED, selector.disconnected().get(id)); + verifySelectorEmpty(); + } + + /** + * Verifies that sockets with incoming data available are not expired. + * For PLAINTEXT, pending receives are always read from socket without any buffering, so this + * test is only verifying that channels are not expired while there is data to read from socket. + * For SSL, pending receives may also be in SSL netReadBuffer or appReadBuffer. So the test verifies + * that connection is not expired when data is available from buffers or network. + */ + @Test + public void testCloseOldestConnectionWithMultiplePendingReceives() throws Exception { + int expectedReceives = 5; + KafkaChannel channel = createConnectionWithPendingReceives(expectedReceives); String id = channel.id(); - int stagedReceives = selector.numStagedReceives(channel); int completedReceives = 0; while (selector.disconnected().isEmpty()) { time.sleep(6000); // The max idle time is 5000ms - selector.poll(0); + selector.poll(completedReceives == expectedReceives ? 0 : 1000); completedReceives += selector.completedReceives().size(); - // With SSL, more receives may be staged from buffered data - int newStaged = selector.numStagedReceives(channel) - (stagedReceives - completedReceives); - if (newStaged > 0) { - stagedReceives += newStaged; - assertNotNull("Channel should not have been expired", selector.channel(id)); - assertFalse("Channel should not have been disconnected", selector.disconnected().containsKey(id)); - } else if (!selector.completedReceives().isEmpty()) { + if (!selector.completedReceives().isEmpty()) { assertEquals(1, selector.completedReceives().size()); + assertNotNull("Channel should not have been expired", selector.channel(id)); assertTrue("Channel not found", selector.closingChannel(id) != null || selector.channel(id) != null); assertFalse("Disconnect notified too early", selector.disconnected().containsKey(id)); } } - assertEquals(stagedReceives, completedReceives); + assertEquals(expectedReceives, completedReceives); assertNull("Channel not removed", selector.channel(id)); assertNull("Channel not removed", selector.closingChannel(id)); assertTrue("Disconnect not notified", selector.disconnected().containsKey(id)); assertTrue("Unexpected receive", selector.completedReceives().isEmpty()); } + /** + * Tests that graceful close of channel processes remaining data from socket read buffers. + * Since we cannot determine how much data is available in the buffers, this test verifies that + * multiple receives are completed after server shuts down connections, with retries to tolerate + * cases where data may not be available in the socket buffer. + */ + @Test + public void testGracefulClose() throws Exception { + int maxReceiveCountAfterClose = 0; + for (int i = 6; i <= 100 && maxReceiveCountAfterClose < 5; i++) { + int receiveCount = 0; + KafkaChannel channel = createConnectionWithPendingReceives(i); + selector.poll(1000); + assertEquals(1, selector.completedReceives().size()); // wait for first receive + server.closeConnections(); + while (selector.disconnected().isEmpty()) { + selector.poll(1); + receiveCount += selector.completedReceives().size(); + assertTrue("Too many completed receives in one poll", selector.completedReceives().size() <= 1); + } + assertEquals(channel.id(), selector.disconnected().keySet().iterator().next()); + maxReceiveCountAfterClose = Math.max(maxReceiveCountAfterClose, receiveCount); + } + assertTrue("Too few receives after close: " + maxReceiveCountAfterClose, maxReceiveCountAfterClose >= 5); + } + + /** + * Tests that graceful close is not delayed if only part of an incoming receive is + * available in the socket buffer. + */ + @Test + public void testPartialReceiveGracefulClose() throws Exception { + String id = "0"; + blockingConnect(id); + KafkaChannel channel = selector.channel(id); + // Inject a NetworkReceive into Kafka channel with a large size + injectNetworkReceive(channel, 100000); + sendNoReceive(channel, 2); // Send some data that gets received as part of injected receive + selector.poll(1000); // Wait until some data arrives, but not a completed receive + assertEquals(0, selector.completedReceives().size()); + server.closeConnections(); + TestUtils.waitForCondition(() -> { + try { + selector.poll(100); + return !selector.disconnected().isEmpty(); + } catch (IOException e) { + throw new RuntimeException(e); + } + }, 10000, "Channel not disconnected"); + assertEquals(1, selector.disconnected().size()); + assertEquals(channel.id(), selector.disconnected().keySet().iterator().next()); + assertEquals(0, selector.completedReceives().size()); + } + @Test public void testMuteOnOOM() throws Exception { //clean up default selector, replace it with one that uses a finite mem pool @@ -617,14 +656,14 @@ public class SelectorTest { selector.register("clientX", channelX); selector.register("clientY", channelY); - List completed = Collections.emptyList(); + Collection completed = Collections.emptyList(); long deadline = System.currentTimeMillis() + 5000; while (System.currentTimeMillis() < deadline && completed.isEmpty()) { selector.poll(1000); completed = selector.completedReceives(); } assertEquals("could not read a single request within timeout", 1, completed.size()); - NetworkReceive firstReceive = completed.get(0); + NetworkReceive firstReceive = completed.iterator().next(); assertEquals(0, pool.availableMemory()); assertTrue(selector.isOutOfMemory()); @@ -976,4 +1015,47 @@ public class SelectorTest { assertNotNull(metric); return metric; } + + /** + * Creates a connection, sends the specified number of requests and returns without reading + * any incoming data. Some of the incoming data may be in the socket buffers when this method + * returns, but there is no guarantee that all the data from the server will be available + * immediately. + */ + private KafkaChannel createConnectionWithPendingReceives(int pendingReceives) throws Exception { + String id = "0"; + blockingConnect(id); + KafkaChannel channel = selector.channel(id); + sendNoReceive(channel, pendingReceives); + return channel; + } + + /** + * Sends the specified number of requests and waits for the requests to be sent. The channel + * is muted during polling to ensure that incoming data is not received. + */ + private KafkaChannel sendNoReceive(KafkaChannel channel, int numRequests) throws Exception { + channel.mute(); + for (int i = 0; i < numRequests; i++) { + selector.send(createSend(channel.id(), String.valueOf(i))); + do { + selector.poll(10); + } while (selector.completedSends().isEmpty()); + } + channel.maybeUnmute(); + + return channel; + } + + /** + * Injects a NetworkReceive for channel with size buffer filled in with the provided size + * and a payload buffer allocated with that size, but no data in the payload buffer. + */ + private void injectNetworkReceive(KafkaChannel channel, int size) throws Exception { + NetworkReceive receive = new NetworkReceive(); + TestUtils.setFieldValue(channel, "receive", receive); + ByteBuffer sizeBuffer = TestUtils.fieldValue(receive, NetworkReceive.class, "size"); + sizeBuffer.putInt(size); + TestUtils.setFieldValue(receive, "buffer", ByteBuffer.allocate(size)); + } } diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java index 86c760ec7fb..8a36fbfa594 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/network/SslSelectorTest.java @@ -44,6 +44,7 @@ import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.security.Security; import java.util.ArrayList; +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.List; @@ -92,6 +93,11 @@ public class SslSelectorTest extends SelectorTest { return SecurityProtocol.PLAINTEXT; } + @Override + protected Map clientConfigs() { + return sslClientConfigs; + } + @Test public void testConnectionWithCustomKeyManager() throws Exception { @@ -315,11 +321,11 @@ public class SslSelectorTest extends SelectorTest { while (System.currentTimeMillis() < deadline) { selector.poll(10); - List completed = selector.completedReceives(); + Collection completed = selector.completedReceives(); if (firstReceive == null) { if (!completed.isEmpty()) { assertEquals("expecting a single request", 1, completed.size()); - firstReceive = completed.get(0); + firstReceive = completed.iterator().next(); assertTrue(selector.isMadeReadProgressLastPoll()); assertEquals(0, pool.availableMemory()); } @@ -343,7 +349,7 @@ public class SslSelectorTest extends SelectorTest { firstReceive.close(); assertEquals(900, pool.availableMemory()); //memory has been released back to pool - List completed = Collections.emptyList(); + Collection completed = Collections.emptyList(); deadline = System.currentTimeMillis() + 5000; while (System.currentTimeMillis() < deadline && completed.isEmpty()) { selector.poll(1000); diff --git a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java index 3b80d4b89a1..2809cbd8834 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java +++ b/clients/src/test/java/org/apache/kafka/common/network/SslTransportLayerTest.java @@ -652,9 +652,9 @@ public class SslTransportLayerTest { // Read the message from socket with only one poll() selector.poll(1000L); - List receiveList = selector.completedReceives(); + Collection receiveList = selector.completedReceives(); assertEquals(1, receiveList.size()); - assertEquals(message, new String(Utils.toArray(receiveList.get(0).payload()))); + assertEquals(message, new String(Utils.toArray(receiveList.iterator().next().payload()))); } /** @@ -737,7 +737,6 @@ public class SslTransportLayerTest { public boolean conditionMet() { try { selector.poll(100L); - assertEquals(0, selector.numStagedReceives(channel)); } catch (IOException e) { return false; } diff --git a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java index 76894a9991c..0e1080bcef3 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/authenticator/SaslAuthenticatorTest.java @@ -2051,7 +2051,7 @@ public class SaslAuthenticatorTest { selector.poll(1000); } while (selector.completedReceives().isEmpty() && waitSeconds-- > 0); assertEquals(1, selector.completedReceives().size()); - return selector.completedReceives().get(0).payload(); + return selector.completedReceives().iterator().next().payload(); } public static class TestServerCallbackHandler extends PlainServerCallbackHandler { diff --git a/clients/src/test/java/org/apache/kafka/test/TestUtils.java b/clients/src/test/java/org/apache/kafka/test/TestUtils.java index ece5af30770..a53d1fcf2a5 100644 --- a/clients/src/test/java/org/apache/kafka/test/TestUtils.java +++ b/clients/src/test/java/org/apache/kafka/test/TestUtils.java @@ -562,4 +562,10 @@ public class TestUtils { throw new RuntimeException(e); } } + + public static void setFieldValue(Object obj, String fieldName, Object value) throws Exception { + Field field = obj.getClass().getDeclaredField(fieldName); + field.setAccessible(true); + field.set(obj, value); + } } diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index e7cc81ae1ed..35d9d7cc6e7 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -1092,10 +1092,6 @@ private[kafka] class Processor(val id: Int, private[network] def channel(connectionId: String): Option[KafkaChannel] = Option(selector.channel(connectionId)) - // Visible for testing - private[network] def numStagedReceives(connectionId: String): Int = - openOrClosingChannel(connectionId).map(c => selector.numStagedReceives(c)).getOrElse(0) - /** * Wakeup the thread for selection. */ diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 33dada06c23..7c77a979dbf 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -20,13 +20,15 @@ package kafka.network import java.io._ import java.net._ import java.nio.ByteBuffer -import java.nio.channels.SocketChannel -import java.util.concurrent.{CompletableFuture, Executors} +import java.nio.channels.{SelectionKey, SocketChannel} +import java.util +import java.util.concurrent.{CompletableFuture, ConcurrentLinkedQueue, Executors, TimeUnit} import java.util.{HashMap, Properties, Random} import com.yammer.metrics.core.{Gauge, Meter} import com.yammer.metrics.{Metrics => YammerMetrics} import javax.net.ssl._ + import kafka.security.CredentialProvider import kafka.server.{KafkaConfig, ThrottledChannel} import kafka.utils.Implicits._ @@ -38,7 +40,7 @@ import org.apache.kafka.common.message.SaslHandshakeRequestData import org.apache.kafka.common.metrics.Metrics import org.apache.kafka.common.network.ClientInformation import org.apache.kafka.common.network.KafkaChannel.ChannelMuteState -import org.apache.kafka.common.network.{ChannelBuilder, ChannelState, KafkaChannel, ListenerName, NetworkReceive, NetworkSend, Selector, Send} +import org.apache.kafka.common.network._ import org.apache.kafka.common.protocol.{ApiKeys, Errors} import org.apache.kafka.common.record.MemoryRecords import org.apache.kafka.common.requests.{AbstractRequest, ApiVersionsRequest, ProduceRequest, RequestHeader, SaslAuthenticateRequest, SaslHandshakeRequest} @@ -46,7 +48,7 @@ import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} import org.apache.kafka.common.security.scram.internals.ScramMechanism import org.apache.kafka.common.utils.AppInfoParser import org.apache.kafka.common.utils.{LogContext, MockTime, Time} -import org.apache.kafka.test.TestSslUtils +import org.apache.kafka.test.{TestSslUtils, TestUtils => JTestUtils} import org.apache.log4j.Level import org.junit.Assert._ import org.junit._ @@ -148,18 +150,45 @@ class SocketServerTest { channel.sendResponse(new RequestChannel.SendResponse(request, send, Some(request.header.toString), None)) } + def processRequestNoOpResponse(channel: RequestChannel, request: RequestChannel.Request): Unit = { + channel.sendResponse(new RequestChannel.NoOpResponse(request)) + } + def connect(s: SocketServer = server, listenerName: ListenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT), localAddr: InetAddress = null, - port: Int = 0) = { + port: Int = 0): Socket = { val socket = new Socket("localhost", s.boundPort(listenerName), localAddr, port) sockets += socket socket } + def sslConnect(s: SocketServer = server): Socket = { + val socket = sslClientSocket(s.boundPort(ListenerName.forSecurityProtocol(SecurityProtocol.SSL))) + sockets += socket + socket + } + + private def sslClientSocket(port: Int): Socket = { + val sslContext = SSLContext.getInstance("TLSv1.2") + sslContext.init(null, Array(TestUtils.trustAllCerts), new java.security.SecureRandom()) + val socketFactory = sslContext.getSocketFactory + val socket = socketFactory.createSocket("localhost", port) + socket.asInstanceOf[SSLSocket].setNeedClientAuth(false) + socket + } + // Create a client connection, process one request and return (client socket, connectionId) def connectAndProcessRequest(s: SocketServer): (Socket, String) = { - val socket = connect(s) + val securityProtocol = s.dataPlaneAcceptors.asScala.head._1.securityProtocol + val socket = securityProtocol match { + case SecurityProtocol.PLAINTEXT | SecurityProtocol.SASL_PLAINTEXT => + connect(s) + case SecurityProtocol.SSL | SecurityProtocol.SASL_SSL => + sslConnect(s) + case _ => + throw new IllegalStateException(s"Unexpected security protocol $securityProtocol") + } val request = sendAndReceiveRequest(socket, s) processRequest(s.dataPlaneRequestChannel, request) (socket, request.context.connectionId) @@ -357,7 +386,7 @@ class SocketServerTest { for (_ <- 0 until 10) { val request = receiveRequest(server.dataPlaneRequestChannel) assertNotNull("receiveRequest timed out", request) - server.dataPlaneRequestChannel.sendResponse(new RequestChannel.NoOpResponse(request)) + processRequestNoOpResponse(server.dataPlaneRequestChannel, request) } } @@ -371,7 +400,7 @@ class SocketServerTest { for (_ <- 0 until 3) { val request = receiveRequest(server.dataPlaneRequestChannel) assertNotNull("receiveRequest timed out", request) - server.dataPlaneRequestChannel.sendResponse(new RequestChannel.NoOpResponse(request)) + processRequestNoOpResponse(server.dataPlaneRequestChannel, request) } } @@ -400,37 +429,41 @@ class SocketServerTest { val serverMetrics = new Metrics val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, time, credentialProvider) - def openChannel(request: RequestChannel.Request): Option[KafkaChannel] = - overrideServer.dataPlaneProcessor(request.processor).channel(request.context.connectionId) - def openOrClosingChannel(request: RequestChannel.Request): Option[KafkaChannel] = - overrideServer.dataPlaneProcessor(request.processor).openOrClosingChannel(request.context.connectionId) - try { overrideServer.startup() val serializedBytes = producerRequestBytes() - // Connection with no staged receives + // Connection with no outstanding requests + val socket0 = connect(overrideServer) + sendRequest(socket0, serializedBytes) + val request0 = receiveRequest(overrideServer.dataPlaneRequestChannel) + processRequest(overrideServer.dataPlaneRequestChannel, request0) + assertTrue("Channel not open", openChannel(request0, overrideServer).nonEmpty) + assertEquals(openChannel(request0, overrideServer), openOrClosingChannel(request0, overrideServer)) + TestUtils.waitUntilTrue(() => !openChannel(request0, overrideServer).get.isMuted, "Failed to unmute channel") + time.sleep(idleTimeMs + 1) + TestUtils.waitUntilTrue(() => openOrClosingChannel(request0, overrideServer).isEmpty, "Failed to close idle channel") + assertTrue("Channel not removed", openChannel(request0, overrideServer).isEmpty) + + // Connection with one request being processed (channel is muted), no other in-flight requests val socket1 = connect(overrideServer) sendRequest(socket1, serializedBytes) val request1 = receiveRequest(overrideServer.dataPlaneRequestChannel) - assertTrue("Channel not open", openChannel(request1).nonEmpty) - assertEquals(openChannel(request1), openOrClosingChannel(request1)) - + assertTrue("Channel not open", openChannel(request1, overrideServer).nonEmpty) + assertEquals(openChannel(request1, overrideServer), openOrClosingChannel(request1, overrideServer)) time.sleep(idleTimeMs + 1) - TestUtils.waitUntilTrue(() => openOrClosingChannel(request1).isEmpty, "Failed to close idle channel") - assertTrue("Channel not removed", openChannel(request1).isEmpty) + TestUtils.waitUntilTrue(() => openOrClosingChannel(request1, overrideServer).isEmpty, "Failed to close idle channel") + assertTrue("Channel not removed", openChannel(request1, overrideServer).isEmpty) processRequest(overrideServer.dataPlaneRequestChannel, request1) - // Connection with staged receives + // Connection with one request being processed (channel is muted), more in-flight requests val socket2 = connect(overrideServer) - val request2 = sendRequestsUntilStagedReceive(overrideServer, socket2, serializedBytes) - + val request2 = sendRequestsReceiveOne(overrideServer, socket2, serializedBytes, 3) time.sleep(idleTimeMs + 1) - TestUtils.waitUntilTrue(() => openChannel(request2).isEmpty, "Failed to close idle channel") - TestUtils.waitUntilTrue(() => openOrClosingChannel(request2).nonEmpty, "Channel removed without processing staged receives") + TestUtils.waitUntilTrue(() => openOrClosingChannel(request2, overrideServer).isEmpty, "Failed to close idle channel") + assertTrue("Channel not removed", openChannel(request1, overrideServer).isEmpty) processRequest(overrideServer.dataPlaneRequestChannel, request2) // this triggers a failed send since channel has been closed - TestUtils.waitUntilTrue(() => openOrClosingChannel(request2).isEmpty, "Failed to remove channel with failed sends") - assertNull("Received request after failed send", overrideServer.dataPlaneRequestChannel.receiveRequest(200)) + assertNull("Received request on expired channel", overrideServer.dataPlaneRequestChannel.receiveRequest(200)) } finally { shutdownServerAndMetrics(overrideServer) @@ -442,7 +475,7 @@ class SocketServerTest { val idleTimeMs = 60000 val time = new MockTime() props.put(KafkaConfig.ConnectionsMaxIdleMsProp, idleTimeMs.toString) - props.put("listeners", "PLAINTEXT://localhost:0") + props ++= sslServerProps val serverMetrics = new Metrics @volatile var selector: TestableSelector = null val overrideConnectionId = "127.0.0.1:1-127.0.0.1:2-0" @@ -471,7 +504,7 @@ class SocketServerTest { // only after `register` is processed by the server. def connectAndWaitForConnectionRegister(): Socket = { val connections = selector.operationCounts(SelectorOperation.Register) - val socket = connect(overrideServer) + val socket = sslConnect(overrideServer) TestUtils.waitUntilTrue(() => selector.operationCounts(SelectorOperation.Register) == connections + 1, "Connection not registered") socket @@ -488,22 +521,20 @@ class SocketServerTest { connectAndWaitForConnectionRegister() TestUtils.waitUntilTrue(() => connectionCount == 1, "Failed to close channel") assertSame(channel1, openChannel.getOrElse(throw new RuntimeException("Channel not found"))) + socket1.close() + TestUtils.waitUntilTrue(() => openChannel.isEmpty, "Channel not closed") - // Send requests to `channel1` until a receive is staged and advance time beyond idle time so that `channel1` is - // closed with staged receives and is in Selector.closingChannels - val serializedBytes = producerRequestBytes() - val request = sendRequestsUntilStagedReceive(overrideServer, socket1, serializedBytes) - time.sleep(idleTimeMs + 1) - TestUtils.waitUntilTrue(() => openChannel.isEmpty, "Idle channel not closed") - TestUtils.waitUntilTrue(() => openOrClosingChannel.isDefined, "Channel removed without processing staged receives") + // Create a channel with buffered receive and close remote connection + val request = makeChannelWithBufferedRequestsAndCloseRemote(overrideServer, selector) + val channel2 = openChannel.getOrElse(throw new RuntimeException("Channel not found")) - // Create new connection with same id when `channel1` is in Selector.closingChannels - // Check that new connection is closed and openOrClosingChannel still contains `channel1` + // Create new connection with same id when `channel2` is closing, but still in Selector.channels + // Check that new connection is closed and openOrClosingChannel still contains `channel2` connectAndWaitForConnectionRegister() TestUtils.waitUntilTrue(() => connectionCount == 1, "Failed to close channel") - assertSame(channel1, openOrClosingChannel.getOrElse(throw new RuntimeException("Channel not found"))) + assertSame(channel2, openOrClosingChannel.getOrElse(throw new RuntimeException("Channel not found"))) - // Complete request with failed send so that `channel1` is removed from Selector.closingChannels + // Complete request with failed send so that `channel2` is removed from Selector.channels processRequest(overrideServer.dataPlaneRequestChannel, request) TestUtils.waitUntilTrue(() => connectionCount == 0 && openOrClosingChannel.isEmpty, "Failed to remove channel with failed send") @@ -519,23 +550,91 @@ class SocketServerTest { } } - private def sendRequestsUntilStagedReceive(server: SocketServer, socket: Socket, requestBytes: Array[Byte]): RequestChannel.Request = { - def sendTwoRequestsReceiveOne(): RequestChannel.Request = { - sendRequest(socket, requestBytes, flush = false) - sendRequest(socket, requestBytes, flush = true) - receiveRequest(server.dataPlaneRequestChannel) + private def makeSocketWithBufferedRequests(server: SocketServer, + serverSelector: Selector, + proxyServer: ProxyServer, + numBufferedRequests: Int = 2): (Socket, RequestChannel.Request) = { + + val requestBytes = producerRequestBytes() + val socket = sslClientSocket(proxyServer.localPort) + sendRequest(socket, requestBytes) + val request1 = receiveRequest(server.dataPlaneRequestChannel) + + val connectionId = request1.context.connectionId + val channel = server.dataPlaneProcessor(0).channel(connectionId).getOrElse(throw new IllegalStateException("Channel not found")) + val transportLayer: SslTransportLayer = JTestUtils.fieldValue(channel, classOf[KafkaChannel], "transportLayer") + val netReadBuffer: ByteBuffer = JTestUtils.fieldValue(transportLayer, classOf[SslTransportLayer], "netReadBuffer") + + proxyServer.enableBuffering(netReadBuffer) + (1 to numBufferedRequests).foreach { _ => sendRequest(socket, requestBytes) } + + val keysWithBufferedRead: util.Set[SelectionKey] = JTestUtils.fieldValue(serverSelector, classOf[Selector], "keysWithBufferedRead") + keysWithBufferedRead.add(channel.selectionKey) + JTestUtils.setFieldValue(transportLayer, "hasBytesBuffered", true) + + (socket, request1) + } + + /** + * Create a channel with data in SSL buffers and close the remote connection. + * The channel should remain open in SocketServer even if it detects that the peer has closed + * the connection since there is pending data to be processed. + */ + private def makeChannelWithBufferedRequestsAndCloseRemote(server: SocketServer, + serverSelector: Selector, + makeClosing: Boolean = false): RequestChannel.Request = { + + val proxyServer = new ProxyServer(server) + try { + val (socket, request1) = makeSocketWithBufferedRequests(server, serverSelector, proxyServer) + + socket.close() + proxyServer.serverConnSocket.close() + TestUtils.waitUntilTrue(() => proxyServer.clientConnSocket.isClosed, "Client socket not closed", waitTimeMs = 10000) + + processRequestNoOpResponse(server.dataPlaneRequestChannel, request1) + val channel = openOrClosingChannel(request1, server).getOrElse(throw new IllegalStateException("Channel closed too early")) + if (makeClosing) + serverSelector.asInstanceOf[TestableSelector].pendingClosingChannels.add(channel) + + receiveRequest(server.dataPlaneRequestChannel, timeout = 10000) + } finally { + proxyServer.close() } - val (request, hasStagedReceives) = TestUtils.computeUntilTrue(sendTwoRequestsReceiveOne()) { req => - val connectionId = req.context.connectionId - val hasStagedReceives = server.dataPlaneProcessor(0).numStagedReceives(connectionId) > 0 - if (!hasStagedReceives) { - processRequest(server.dataPlaneRequestChannel, req) - processRequest(server.dataPlaneRequestChannel) + } + + def sendRequestsReceiveOne(server: SocketServer, socket: Socket, requestBytes: Array[Byte], numRequests: Int): RequestChannel.Request = { + (1 to numRequests).foreach(i => sendRequest(socket, requestBytes, flush = i == numRequests)) + receiveRequest(server.dataPlaneRequestChannel) + } + + private def closeSocketWithPendingRequest(server: SocketServer, + createSocket: () => Socket): RequestChannel.Request = { + + def maybeReceiveRequest(): Option[RequestChannel.Request] = { + try { + Some(receiveRequest(server.dataPlaneRequestChannel, timeout = 1000)) + } catch { + case e: Exception => None } - hasStagedReceives } - assertTrue(s"Receives not staged for ${org.apache.kafka.test.TestUtils.DEFAULT_MAX_WAIT_MS} ms", hasStagedReceives) - request + + def closedChannelWithPendingRequest(): Option[RequestChannel.Request] = { + val socket = createSocket.apply() + val req1 = sendRequestsReceiveOne(server, socket, producerRequestBytes(ack = 0), numRequests = 100) + processRequestNoOpResponse(server.dataPlaneRequestChannel, req1) + // Set SoLinger to 0 to force a hard disconnect via TCP RST + socket.setSoLinger(true, 0) + socket.close() + + maybeReceiveRequest().flatMap { req => + processRequestNoOpResponse(server.dataPlaneRequestChannel, req) + maybeReceiveRequest() + } + } + + val (request, _) = TestUtils.computeUntilTrue(closedChannelWithPendingRequest()) { req => req.nonEmpty } + request.getOrElse(throw new IllegalStateException("Could not create close channel with pending request")) } // Prepares test setup for throttled channel tests. throttlingDone controls whether or not throttling has completed @@ -568,7 +667,10 @@ class SocketServerTest { request } - def openOrClosingChannel(request: RequestChannel.Request): Option[KafkaChannel] = + def openChannel(request: RequestChannel.Request, server: SocketServer = this.server): Option[KafkaChannel] = + server.dataPlaneProcessor(0).channel(request.context.connectionId) + + def openOrClosingChannel(request: RequestChannel.Request, server: SocketServer = this.server): Option[KafkaChannel] = server.dataPlaneProcessor(0).openOrClosingChannel(request.context.connectionId) @Test @@ -740,13 +842,8 @@ class SocketServerTest { @Test def testSslSocketServer(): Unit = { - val trustStoreFile = File.createTempFile("truststore", ".jks") - val overrideProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, interBrokerSecurityProtocol = Some(SecurityProtocol.SSL), - trustStoreFile = Some(trustStoreFile)) - overrideProps.put(KafkaConfig.ListenersProp, "SSL://localhost:0") - val serverMetrics = new Metrics - val overrideServer = new SocketServer(KafkaConfig.fromProps(overrideProps), serverMetrics, Time.SYSTEM, credentialProvider) + val overrideServer = new SocketServer(KafkaConfig.fromProps(sslServerProps), serverMetrics, Time.SYSTEM, credentialProvider) try { overrideServer.startup() val sslContext = SSLContext.getInstance(TestSslUtils.DEFAULT_TLS_PROTOCOL_FOR_TESTS) @@ -919,17 +1016,15 @@ class SocketServerTest { } @Test - def testClientDisconnectionWithStagedReceivesFullyProcessed(): Unit = { + def testClientDisconnectionWithOutstandingReceivesProcessedUntilFailedSend() { val serverMetrics = new Metrics @volatile var selector: TestableSelector = null - val overrideConnectionId = "127.0.0.1:1-127.0.0.1:2-0" val overrideServer = new SocketServer(KafkaConfig.fromProps(props), serverMetrics, Time.SYSTEM, credentialProvider) { override def newProcessor(id: Int, requestChannel: RequestChannel, connectionQuotas: ConnectionQuotas, listenerName: ListenerName, protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = { new Processor(id, time, config.socketRequestMaxBytes, dataPlaneRequestChannel, connectionQuotas, config.connectionsMaxIdleMs, config.failedAuthenticationDelayMs, listenerName, protocol, config, metrics, credentialProvider, memoryPool, new LogContext()) { - override protected[network] def connectionId(socket: Socket): String = overrideConnectionId override protected[network] def createSelector(channelBuilder: ChannelBuilder): Selector = { val testableSelector = new TestableSelector(config, channelBuilder, time, metrics) selector = testableSelector @@ -939,27 +1034,15 @@ class SocketServerTest { } } - def openChannel: Option[KafkaChannel] = overrideServer.dataPlaneProcessor(0).channel(overrideConnectionId) - def openOrClosingChannel: Option[KafkaChannel] = overrideServer.dataPlaneProcessor(0).openOrClosingChannel(overrideConnectionId) - try { overrideServer.startup() - val socket = connect(overrideServer) - TestUtils.waitUntilTrue(() => openChannel.nonEmpty, "Channel not found") + // Create a channel, send some requests and close socket. Receive one pending request after socket was closed. + val request = closeSocketWithPendingRequest(overrideServer, () => connect(overrideServer)) - // Setup channel to client with staged receives so when client disconnects - // it will be stored in Selector.closingChannels - val serializedBytes = producerRequestBytes(1) - val request = sendRequestsUntilStagedReceive(overrideServer, socket, serializedBytes) - - // Set SoLinger to 0 to force a hard disconnect via TCP RST - socket.setSoLinger(true, 0) - socket.close() - - // Complete request with socket exception so that the channel is removed from Selector.closingChannels + // Complete request with socket exception so that the channel is closed processRequest(overrideServer.dataPlaneRequestChannel, request) - TestUtils.waitUntilTrue(() => openOrClosingChannel.isEmpty, "Channel not closed after failed send") + TestUtils.waitUntilTrue(() => openOrClosingChannel(request, overrideServer).isEmpty, "Channel not closed after failed send") assertTrue("Unexpected completed send", selector.completedSends.isEmpty) } finally { overrideServer.shutdown() @@ -1143,30 +1226,233 @@ class SocketServerTest { } /** - * Tests exception handling in [[Processor.processNewResponses]] when [[Selector.send]] - * to a channel in closing state throws an exception. Test scenario is similar to - * [[SocketServerTest.processNewResponseException]]. + * Tests channel send failure handling when send failure is triggered by [[Selector.send]] + * to a channel whose peer has closed its connection. */ @Test - def closingChannelException(): Unit = { + def remoteCloseSendFailure(): Unit = { + verifySendFailureAfterRemoteClose(makeClosing = false) + } + + /** + * Tests channel send failure handling when send failure is triggered by [[Selector.send]] + * to a channel whose peer has closed its connection and the channel is in `closingChannels`. + */ + @Test + def closingChannelSendFailure(): Unit = { + verifySendFailureAfterRemoteClose(makeClosing = true) + } + + private def verifySendFailureAfterRemoteClose(makeClosing: Boolean): Unit = { + props ++= sslServerProps withTestableServer (testWithServer = { testableServer => val testableSelector = testableServer.testableSelector + + val serializedBytes = producerRequestBytes() + val request = makeChannelWithBufferedRequestsAndCloseRemote(testableServer, testableSelector, makeClosing) + val otherSocket = sslConnect(testableServer) + sendRequest(otherSocket, serializedBytes) + + processRequest(testableServer.dataPlaneRequestChannel, request) + processRequest(testableServer.dataPlaneRequestChannel) // Also process request from other socket + testableSelector.waitForOperations(SelectorOperation.Send, 2) + testableServer.waitForChannelClose(request.context.connectionId, locallyClosed = false) + + assertProcessorHealthy(testableServer, Seq(otherSocket)) + }) + } + + /** + * Verifies that all pending buffered receives are processed even if remote connection is closed. + * The channel must be closed after pending receives are processed. + */ + @Test + def remoteCloseWithBufferedReceives(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 3, hasIncomplete = false) + } + + /** + * Verifies that channel is closed when remote client closes its connection if there is no + * buffered receive. + */ + @Test + def remoteCloseWithoutBufferedReceives(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 0, hasIncomplete = false) + } + + /** + * Verifies that channel is closed when remote client closes its connection if there is a pending + * receive that is incomplete. + */ + @Test + def remoteCloseWithIncompleteBufferedReceive(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 0, hasIncomplete = true) + } + + /** + * Verifies that all pending buffered receives are processed even if remote connection is closed. + * The channel must be closed after complete receives are processed, even if there is an incomplete + * receive remaining in the buffers. + */ + @Test + def remoteCloseWithCompleteAndIncompleteBufferedReceives(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 3, hasIncomplete = true) + } + + /** + * Verifies that pending buffered receives are processed when remote connection is closed + * until a response send fails. + */ + @Test + def remoteCloseWithBufferedReceivesFailedSend(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 3, hasIncomplete = false, responseRequiredIndex = 1) + } + + /** + * Verifies that all pending buffered receives are processed for channel in closing state. + * The channel must be closed after pending receives are processed. + */ + @Test + def closingChannelWithBufferedReceives(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 3, hasIncomplete = false, makeClosing = true) + } + + /** + * Verifies that all pending buffered receives are processed for channel in closing state. + * The channel must be closed after complete receives are processed, even if there is an incomplete + * receive remaining in the buffers. + */ + @Test + def closingChannelWithCompleteAndIncompleteBufferedReceives(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 3, hasIncomplete = true, makeClosing = false) + } + + /** + * Verifies that pending buffered receives are processed for a channel in closing state + * until a response send fails. + */ + @Test + def closingChannelWithBufferedReceivesFailedSend(): Unit = { + verifyRemoteCloseWithBufferedReceives(numComplete = 3, hasIncomplete = false, responseRequiredIndex = 1, makeClosing = false) + } + + /** + * Verifies handling of client disconnections when the server-side channel is in the state + * specified using the parameters. + * + * @param numComplete Number of complete buffered requests + * @param hasIncomplete If true, add an additional partial buffered request + * @param responseRequiredIndex Index of the buffered request for which a response is sent. Previous requests + * are completed without a response. If set to -1, all `numComplete` requests + * are completed without a response. + * @param makeClosing If true, put the channel into closing state in the server Selector. + */ + private def verifyRemoteCloseWithBufferedReceives(numComplete: Int, + hasIncomplete: Boolean, + responseRequiredIndex: Int = -1, + makeClosing: Boolean = false): Unit = { + props ++= sslServerProps + + // Truncates the last request in the SSL buffers by directly updating the buffers to simulate partial buffered request + def truncateBufferedRequest(channel: KafkaChannel): Unit = { + val transportLayer: SslTransportLayer = JTestUtils.fieldValue(channel, classOf[KafkaChannel], "transportLayer") + val netReadBuffer: ByteBuffer = JTestUtils.fieldValue(transportLayer, classOf[SslTransportLayer], "netReadBuffer") + val appReadBuffer: ByteBuffer = JTestUtils.fieldValue(transportLayer, classOf[SslTransportLayer], "appReadBuffer") + if (appReadBuffer.position() > 4) { + appReadBuffer.position(4) + netReadBuffer.position(0) + } else { + netReadBuffer.position(20) + } + } + withTestableServer (testWithServer = { testableServer => + val testableSelector = testableServer.testableSelector + + val proxyServer = new ProxyServer(testableServer) + try { + // Step 1: Send client requests. + // a) request1 is sent by the client to ProxyServer and this is directly sent to the server. This + // ensures that server-side channel is in muted state until this request is processed in Step 3. + // b) `numComplete` requests are sent and buffered in the server-side channel's SSL buffers + // c) If `hasIncomplete=true`, an extra request is sent and buffered as in b). This will be truncated later + // when previous requests have been processed and only one request is remaining in the SSL buffer, + // making it easy to truncate. + val numBufferedRequests = numComplete + (if (hasIncomplete) 1 else 0) + val (socket, request1) = makeSocketWithBufferedRequests(testableServer, testableSelector, proxyServer, numBufferedRequests) + val channel = openChannel(request1, testableServer).getOrElse(throw new IllegalStateException("Channel closed too early")) + + // Step 2: Close the client-side socket and the proxy socket to the server, triggering close notification in the + // server when the client is unmuted in Step 3. Get the channel into its desired closing/buffered state. + socket.close() + proxyServer.serverConnSocket.close() + TestUtils.waitUntilTrue(() => proxyServer.clientConnSocket.isClosed, "Client socket not closed") + if (makeClosing) + testableSelector.pendingClosingChannels.add(channel) + if (numComplete == 0 && hasIncomplete) + truncateBufferedRequest(channel) + + // Step 3: Process the first request. Verify that the channel is not removed since the channel + // should be retained to process buffered data. + processRequestNoOpResponse(testableServer.dataPlaneRequestChannel, request1) + assertSame(channel, openOrClosingChannel(request1, testableServer).getOrElse(throw new IllegalStateException("Channel closed too early"))) + + // Step 4: Process buffered data. if `responseRequiredIndex>=0`, the channel should be failed and removed when + // attempting to send response. Otherwise, the channel should be removed when all completed buffers are processed. + // Channel should be closed and removed even if there is a partial buffered request when `hasIncomplete=true` + val numRequests = if (responseRequiredIndex >= 0) responseRequiredIndex + 1 else numComplete + (0 until numRequests).foreach { i => + val request = receiveRequest(testableServer.dataPlaneRequestChannel) + if (i == numComplete - 1 && hasIncomplete) + truncateBufferedRequest(channel) + if (responseRequiredIndex == i) + processRequest(testableServer.dataPlaneRequestChannel, request) + else + processRequestNoOpResponse(testableServer.dataPlaneRequestChannel, request) + } + testableServer.waitForChannelClose(channel.id, locallyClosed = false) + + // Verify that SocketServer is healthy + val anotherSocket = sslConnect(testableServer) + assertProcessorHealthy(testableServer, Seq(anotherSocket)) + } finally { + proxyServer.close() + } + }) + } + + /** + * Tests idle channel expiry for SSL channels with buffered data. Muted channels are expired + * immediately even if there is pending data to be processed. This is consistent with PLAINTEXT where + * we expire muted channels even if there is data available on the socket. This scenario occurs if broker + * takes longer than idle timeout to process a client request. In this case, typically client would have + * expired its connection and would potentially reconnect to retry the request, so immediate expiry enables + * the old connection and its associated resources to be freed sooner. + */ + @Test + def idleExpiryWithBufferedReceives(): Unit = { + val idleTimeMs = 60000 + val time = new MockTime() + props.put(KafkaConfig.ConnectionsMaxIdleMsProp, idleTimeMs.toString) + props ++= sslServerProps + val testableServer = new TestableSocketServer(time = time) + testableServer.startup() + val proxyServer = new ProxyServer(testableServer) + try { + val testableSelector = testableServer.testableSelector testableSelector.updateMinWakeup(2) - val sockets = (1 to 2).map(_ => connect(testableServer)) - val serializedBytes = producerRequestBytes() - val request = sendRequestsUntilStagedReceive(testableServer, sockets(0), serializedBytes) - sendRequest(sockets(1), serializedBytes) + val (socket, request) = makeSocketWithBufferedRequests(testableServer, testableSelector, proxyServer) + time.sleep(idleTimeMs + 1) + testableServer.waitForChannelClose(request.context.connectionId, locallyClosed = false) - testableSelector.addFailure(SelectorOperation.Send) - sockets(0).close() - processRequest(testableServer.dataPlaneRequestChannel, request) - processRequest(testableServer.dataPlaneRequestChannel) // Also process request from other channel - testableSelector.waitForOperations(SelectorOperation.Send, 2) - testableServer.waitForChannelClose(request.context.connectionId, locallyClosed = true) + val otherSocket = sslConnect(testableServer) + assertProcessorHealthy(testableServer, Seq(otherSocket)) - assertProcessorHealthy(testableServer, Seq(sockets(1))) - }) + socket.close() + } finally { + proxyServer.close() + shutdownServerAndMetrics(testableServer) + } } /** @@ -1345,8 +1631,16 @@ class SocketServerTest { } } - private def withTestableServer(config : KafkaConfig = config, testWithServer: TestableSocketServer => Unit): Unit = { - props.put("listeners", "PLAINTEXT://localhost:0") + private def sslServerProps: Properties = { + val trustStoreFile = File.createTempFile("truststore", ".jks") + val sslProps = TestUtils.createBrokerConfig(0, TestUtils.MockZkConnect, interBrokerSecurityProtocol = Some(SecurityProtocol.SSL), + trustStoreFile = Some(trustStoreFile)) + sslProps.put(KafkaConfig.ListenersProp, "SSL://localhost:0") + sslProps + } + + private def withTestableServer(config : KafkaConfig = KafkaConfig.fromProps(props), + testWithServer: TestableSocketServer => Unit): Unit = { val testableServer = new TestableSocketServer(config) testableServer.startup() try { @@ -1402,7 +1696,7 @@ class SocketServerTest { } } - class TestableSocketServer(config : KafkaConfig = config, val connectionQueueSize: Int = 20, + class TestableSocketServer(config : KafkaConfig = KafkaConfig.fromProps(props), val connectionQueueSize: Int = 20, override val time: Time = Time.SYSTEM) extends SocketServer(config, new Metrics, time, credentialProvider) { @@ -1493,6 +1787,7 @@ class SocketServerTest { val cachedCompletedSends = new PollData[Send]() val cachedDisconnected = new PollData[(String, ChannelState)]() val allCachedPollData = Seq(cachedCompletedReceives, cachedCompletedSends, cachedDisconnected) + val pendingClosingChannels = new ConcurrentLinkedQueue[KafkaChannel]() @volatile var minWakeupCount = 0 @volatile var pollTimeoutOverride: Option[Long] = None @volatile var pollCallback: () => Unit = () => {} @@ -1538,6 +1833,9 @@ class SocketServerTest { override def poll(timeout: Long): Unit = { try { pollCallback.apply() + while (!pendingClosingChannels.isEmpty) { + makeClosing(pendingClosingChannels.poll()) + } allCachedPollData.foreach(_.reset) runOp(SelectorOperation.Poll, None) { super.poll(pollTimeoutOverride.getOrElse(timeout)) @@ -1545,7 +1843,7 @@ class SocketServerTest { } finally { super.channels.asScala.foreach(allChannels += _.id) allDisconnectedChannels ++= super.disconnected.asScala.keys - cachedCompletedReceives.update(super.completedReceives.asScala) + cachedCompletedReceives.update(super.completedReceives.asScala.toBuffer) cachedCompletedSends.update(super.completedSends.asScala) cachedDisconnected.update(super.disconnected.asScala.toBuffer) } @@ -1611,5 +1909,67 @@ class SocketServerTest { val failedConnectionId = allFailedChannels.head sockets.filterNot(socket => isSocketConnectionId(failedConnectionId, socket)) } + + private def makeClosing(channel: KafkaChannel): Unit = { + val channels: util.Map[String, KafkaChannel] = JTestUtils.fieldValue(this, classOf[Selector], "channels") + val closingChannels: util.Map[String, KafkaChannel] = JTestUtils.fieldValue(this, classOf[Selector], "closingChannels") + closingChannels.put(channel.id, channel) + channels.remove(channel.id) + } + } + + /** + * Proxy server used to intercept connections to SocketServer. This is used for testing SSL channels + * with buffered data. A single SSL client is expected to be created by the test using this ProxyServer. + * By default, data between the client and the server is simply transferred across to the destination by ProxyServer. + * Tests can enable buffering in ProxyServer to directly copy incoming data from the client to the server-side + * channel's `netReadBuffer` to simulate scenarios with SSL buffered data. + */ + private class ProxyServer(socketServer: SocketServer) { + val serverSocket = new ServerSocket(0) + val localPort = serverSocket.getLocalPort + val serverConnSocket = new Socket("localhost", socketServer.boundPort(ListenerName.forSecurityProtocol(SecurityProtocol.SSL))) + val executor = Executors.newFixedThreadPool(2) + @volatile var clientConnSocket: Socket = _ + @volatile var buffer: Option[ByteBuffer] = None + + executor.submit((() => { + try { + clientConnSocket = serverSocket.accept() + val serverOut = serverConnSocket.getOutputStream + val clientIn = clientConnSocket.getInputStream + var b: Int = -1 + while ({b = clientIn.read(); b != -1}) { + buffer match { + case Some(buf) => + buf.put(b.asInstanceOf[Byte]) + case None => + serverOut.write(b) + serverOut.flush() + } + } + } finally { + clientConnSocket.close() + } + }): Runnable) + + executor.submit((() => { + var b: Int = -1 + val serverIn = serverConnSocket.getInputStream + while ({b = serverIn.read(); b != -1}) { + clientConnSocket.getOutputStream.write(b) + } + }): Runnable) + + def enableBuffering(buffer: ByteBuffer): Unit = this.buffer = Some(buffer) + + def close(): Unit = { + serverSocket.close() + serverConnSocket.close() + clientConnSocket.close() + executor.shutdownNow() + assertTrue(executor.awaitTermination(10, TimeUnit.SECONDS)) + } + } }