KAFKA-7352; Allow SASL Connections to Periodically Re-Authenticate (KIP-368) (#5582)

KIP-368 implementation to enable periodic re-authentication of SASL clients. Also adds a broker configuration option to terminate client connections that do not re-authenticate within the configured interval.
This commit is contained in:
Ron Dagostino 2018-10-26 18:18:15 -04:00 committed by Rajini Sivaram
parent 51061792ca
commit e8a3bc7425
52 changed files with 1938 additions and 223 deletions

View File

@ -51,7 +51,7 @@
files="(Utils|Topic|KafkaLZ4BlockOutputStream|AclData).java"/> files="(Utils|Topic|KafkaLZ4BlockOutputStream|AclData).java"/>
<suppress checks="CyclomaticComplexity" <suppress checks="CyclomaticComplexity"
files="(ConsumerCoordinator|Fetcher|Sender|KafkaProducer|BufferPool|ConfigDef|RecordAccumulator|KerberosLogin|AbstractRequest|AbstractResponse|Selector|SslFactory|SslTransportLayer|SaslClientAuthenticator|SaslClientCallbackHandler).java"/> files="(ConsumerCoordinator|Fetcher|Sender|KafkaProducer|BufferPool|ConfigDef|RecordAccumulator|KerberosLogin|AbstractRequest|AbstractResponse|Selector|SslFactory|SslTransportLayer|SaslClientAuthenticator|SaslClientCallbackHandler|SaslServerAuthenticator).java"/>
<suppress checks="JavaNCSS" <suppress checks="JavaNCSS"
files="AbstractRequest.java|KerberosLogin.java|WorkerSinkTaskTest.java|TransactionManagerTest.java"/> files="AbstractRequest.java|KerberosLogin.java|WorkerSinkTaskTest.java|TransactionManagerTest.java"/>

View File

@ -23,6 +23,7 @@ import org.apache.kafka.common.network.ChannelBuilder;
import org.apache.kafka.common.network.ChannelBuilders; import org.apache.kafka.common.network.ChannelBuilders;
import org.apache.kafka.common.security.JaasContext; import org.apache.kafka.common.security.JaasContext;
import org.apache.kafka.common.security.auth.SecurityProtocol; import org.apache.kafka.common.security.auth.SecurityProtocol;
import org.apache.kafka.common.utils.Time;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -102,11 +103,11 @@ public final class ClientUtils {
* @param config client configs * @param config client configs
* @return configured ChannelBuilder based on the configs. * @return configured ChannelBuilder based on the configs.
*/ */
public static ChannelBuilder createChannelBuilder(AbstractConfig config) { public static ChannelBuilder createChannelBuilder(AbstractConfig config, Time time) {
SecurityProtocol securityProtocol = SecurityProtocol.forName(config.getString(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG)); SecurityProtocol securityProtocol = SecurityProtocol.forName(config.getString(CommonClientConfigs.SECURITY_PROTOCOL_CONFIG));
String clientSaslMechanism = config.getString(SaslConfigs.SASL_MECHANISM); String clientSaslMechanism = config.getString(SaslConfigs.SASL_MECHANISM);
return ChannelBuilders.clientChannelBuilder(securityProtocol, JaasContext.Type.CLIENT, config, null, return ChannelBuilders.clientChannelBuilder(securityProtocol, JaasContext.Type.CLIENT, config, null,
clientSaslMechanism, true); clientSaslMechanism, time, true);
} }
static List<InetAddress> resolve(String host, ClientDnsLookup clientDnsLookup) throws UnknownHostException { static List<InetAddress> resolve(String host, ClientDnsLookup clientDnsLookup) throws UnknownHostException {

View File

@ -350,7 +350,7 @@ public class KafkaAdminClient extends AdminClient {
reporters.add(new JmxReporter(JMX_PREFIX)); reporters.add(new JmxReporter(JMX_PREFIX));
metrics = new Metrics(metricConfig, reporters, time); metrics = new Metrics(metricConfig, reporters, time);
String metricGrpPrefix = "admin-client"; String metricGrpPrefix = "admin-client";
channelBuilder = ClientUtils.createChannelBuilder(config); channelBuilder = ClientUtils.createChannelBuilder(config, time);
selector = new Selector(config.getLong(AdminClientConfig.CONNECTIONS_MAX_IDLE_MS_CONFIG), selector = new Selector(config.getLong(AdminClientConfig.CONNECTIONS_MAX_IDLE_MS_CONFIG),
metrics, time, metricGrpPrefix, channelBuilder, logContext); metrics, time, metricGrpPrefix, channelBuilder, logContext);
networkClient = new NetworkClient( networkClient = new NetworkClient(

View File

@ -715,7 +715,7 @@ public class KafkaConsumer<K, V> implements Consumer<K, V> {
this.metadata.update(Cluster.bootstrap(addresses), Collections.<String>emptySet(), 0); this.metadata.update(Cluster.bootstrap(addresses), Collections.<String>emptySet(), 0);
String metricGrpPrefix = "consumer"; String metricGrpPrefix = "consumer";
ConsumerMetrics metricsRegistry = new ConsumerMetrics(metricsTags.keySet(), "consumer"); ConsumerMetrics metricsRegistry = new ConsumerMetrics(metricsTags.keySet(), "consumer");
ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(config); ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(config, time);
IsolationLevel isolationLevel = IsolationLevel.valueOf( IsolationLevel isolationLevel = IsolationLevel.valueOf(
config.getString(ConsumerConfig.ISOLATION_LEVEL_CONFIG).toUpperCase(Locale.ROOT)); config.getString(ConsumerConfig.ISOLATION_LEVEL_CONFIG).toUpperCase(Locale.ROOT));

View File

@ -438,7 +438,7 @@ public class KafkaProducer<K, V> implements Producer<K, V> {
Sender newSender(LogContext logContext, KafkaClient kafkaClient, Metadata metadata) { Sender newSender(LogContext logContext, KafkaClient kafkaClient, Metadata metadata) {
int maxInflightRequests = configureInflightRequests(producerConfig, transactionManager != null); int maxInflightRequests = configureInflightRequests(producerConfig, transactionManager != null);
int requestTimeoutMs = producerConfig.getInt(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG); int requestTimeoutMs = producerConfig.getInt(ProducerConfig.REQUEST_TIMEOUT_MS_CONFIG);
ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(producerConfig); ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(producerConfig, time);
ProducerMetrics metricsRegistry = new ProducerMetrics(this.metrics); ProducerMetrics metricsRegistry = new ProducerMetrics(this.metrics);
Sensor throttleTimeSensor = Sender.throttleTimeSensor(metricsRegistry.senderMetrics); Sensor throttleTimeSensor = Sender.throttleTimeSensor(metricsRegistry.senderMetrics);
KafkaClient client = kafkaClient != null ? kafkaClient : new NetworkClient( KafkaClient client = kafkaClient != null ? kafkaClient : new NetworkClient(

View File

@ -35,6 +35,7 @@ public class BrokerSecurityConfigs {
public static final String SASL_ENABLED_MECHANISMS_CONFIG = "sasl.enabled.mechanisms"; public static final String SASL_ENABLED_MECHANISMS_CONFIG = "sasl.enabled.mechanisms";
public static final String SASL_SERVER_CALLBACK_HANDLER_CLASS = "sasl.server.callback.handler.class"; public static final String SASL_SERVER_CALLBACK_HANDLER_CLASS = "sasl.server.callback.handler.class";
public static final String SSL_PRINCIPAL_MAPPING_RULES_CONFIG = "ssl.principal.mapping.rules"; public static final String SSL_PRINCIPAL_MAPPING_RULES_CONFIG = "ssl.principal.mapping.rules";
public static final String CONNECTIONS_MAX_REAUTH_MS = "connections.max.reauth.ms";
public static final String PRINCIPAL_BUILDER_CLASS_DOC = "The fully qualified name of a class that implements the " + public static final String PRINCIPAL_BUILDER_CLASS_DOC = "The fully qualified name of a class that implements the " +
"KafkaPrincipalBuilder interface, which is used to build the KafkaPrincipal object used during " + "KafkaPrincipalBuilder interface, which is used to build the KafkaPrincipal object used during " +
@ -84,4 +85,9 @@ public class BrokerSecurityConfigs {
+ "listener prefix and SASL mechanism name in lower-case. For example, " + "listener prefix and SASL mechanism name in lower-case. For example, "
+ "listener.name.sasl_ssl.plain.sasl.server.callback.handler.class=com.example.CustomPlainCallbackHandler."; + "listener.name.sasl_ssl.plain.sasl.server.callback.handler.class=com.example.CustomPlainCallbackHandler.";
public static final String CONNECTIONS_MAX_REAUTH_MS_DOC = "When explicitly set to a positive number (the default is 0, not a positive number), "
+ "a session lifetime that will not exceed the configured value will be communicated to v2.2.0 or later clients when they authenticate. "
+ "The broker will disconnect any such connection that is not re-authenticated within the session lifetime and that is then subsequently "
+ "used for any purpose other than re-authentication. Configuration names can optionally be prefixed with listener prefix and SASL "
+ "mechanism name in lower-case. For example, listener.name.sasl_ssl.oauthbearer.connections.max.reauth.ms=3600000";
} }

View File

@ -21,6 +21,8 @@ import org.apache.kafka.common.security.auth.KafkaPrincipal;
import java.io.Closeable; import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.util.Collections;
import java.util.List;
/** /**
* Authentication for Channel * Authentication for Channel
@ -54,4 +56,104 @@ public interface Authenticator extends Closeable {
* returns true if authentication is complete otherwise returns false; * returns true if authentication is complete otherwise returns false;
*/ */
boolean complete(); boolean complete();
/**
* Begins re-authentication. Uses transportLayer to read or write tokens as is
* done for {@link #authenticate()}. For security protocols PLAINTEXT and SSL,
* this is a no-op since re-authentication does not apply/is not supported,
* respectively. For SASL_PLAINTEXT and SASL_SSL, this performs a SASL
* authentication. Any in-flight responses from prior requests can/will be read
* and collected for later processing as required. There must not be partially
* written requests; any request queued for writing (for which zero bytes have
* been written) remains queued until after re-authentication succeeds.
*
* @param reauthenticationContext
* the context in which this re-authentication is occurring. This
* instance is responsible for closing the previous Authenticator
* returned by
* {@link ReauthenticationContext#previousAuthenticator()}.
* @throws AuthenticationException
* if authentication fails due to invalid credentials or other
* security configuration errors
* @throws IOException
* if read/write fails due to an I/O error
*/
default void reauthenticate(ReauthenticationContext reauthenticationContext) throws IOException {
// empty
}
/**
* Return the session expiration time, if any, otherwise null. The value is in
* nanoseconds as per {@code System.nanoTime()} and is therefore only useful
* when compared to such a value -- it's absolute value is meaningless. This
* value may be non-null only on the server-side. It represents the time after
* which, in the absence of re-authentication, the broker will close the session
* if it receives a request unrelated to authentication. We store nanoseconds
* here to avoid having to invoke the more expensive {@code milliseconds()} call
* on the broker for every request
*
* @return the session expiration time, if any, otherwise null
*/
default Long serverSessionExpirationTimeNanos() {
return null;
}
/**
* Return the time on or after which a client should re-authenticate this
* session, if any, otherwise null. The value is in nanoseconds as per
* {@code System.nanoTime()} and is therefore only useful when compared to such
* a value -- it's absolute value is meaningless. This value may be non-null
* only on the client-side. It will be a random time between 85% and 95% of the
* full session lifetime to account for latency between client and server and to
* avoid re-authentication storms that could be caused by many sessions
* re-authenticating simultaneously.
*
* @return the time on or after which a client should re-authenticate this
* session, if any, otherwise null
*/
default Long clientSessionReauthenticationTimeNanos() {
return null;
}
/**
* Return the number of milliseconds that elapsed while re-authenticating this
* session from the perspective of this instance, if applicable, otherwise null.
* The server-side perspective will yield a lower value than the client-side
* perspective of the same re-authentication because the client-side observes an
* additional network round-trip.
*
* @return the number of milliseconds that elapsed while re-authenticating this
* session from the perspective of this instance, if applicable,
* otherwise null
*/
default Long reauthenticationLatencyMs() {
return null;
}
/**
* Return the (always non-null but possibly empty) client-side
* {@link NetworkReceive} responses that arrived during re-authentication that
* are unrelated to re-authentication, if any. These correspond to requests sent
* prior to the beginning of re-authentication; the requests were made when the
* channel was successfully authenticated, and the responses arrived during the
* re-authentication process.
*
* @return the (always non-null but possibly empty) client-side
* {@link NetworkReceive} responses that arrived during
* re-authentication that are unrelated to re-authentication, if any
*/
default List<NetworkReceive> getAndClearResponsesReceivedDuringReauthentication() {
return Collections.emptyList();
}
/**
* Return true if this is a server-side authenticator and the connected client
* has indicated that it supports re-authentication, otherwise false
*
* @return true if this is a server-side authenticator and the connected client
* has indicated that it supports re-authentication, otherwise false
*/
default boolean connectedClientSupportsReauthentication() {
return false;
}
} }

View File

@ -28,6 +28,7 @@ import org.apache.kafka.common.security.authenticator.CredentialCache;
import org.apache.kafka.common.security.kerberos.KerberosShortNamer; import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
import org.apache.kafka.common.security.ssl.SslPrincipalMapper; import org.apache.kafka.common.security.ssl.SslPrincipalMapper;
import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache; import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Utils; import org.apache.kafka.common.utils.Utils;
import java.util.Collections; import java.util.Collections;
@ -36,7 +37,6 @@ import java.util.List;
import java.util.Map; import java.util.Map;
public class ChannelBuilders { public class ChannelBuilders {
private ChannelBuilders() { } private ChannelBuilders() { }
/** /**
@ -55,6 +55,7 @@ public class ChannelBuilders {
AbstractConfig config, AbstractConfig config,
ListenerName listenerName, ListenerName listenerName,
String clientSaslMechanism, String clientSaslMechanism,
Time time,
boolean saslHandshakeRequestEnable) { boolean saslHandshakeRequestEnable) {
if (securityProtocol == SecurityProtocol.SASL_PLAINTEXT || securityProtocol == SecurityProtocol.SASL_SSL) { if (securityProtocol == SecurityProtocol.SASL_PLAINTEXT || securityProtocol == SecurityProtocol.SASL_SSL) {
@ -64,7 +65,7 @@ public class ChannelBuilders {
throw new IllegalArgumentException("`clientSaslMechanism` must be non-null in client mode if `securityProtocol` is `" + securityProtocol + "`"); throw new IllegalArgumentException("`clientSaslMechanism` must be non-null in client mode if `securityProtocol` is `" + securityProtocol + "`");
} }
return create(securityProtocol, Mode.CLIENT, contextType, config, listenerName, false, clientSaslMechanism, return create(securityProtocol, Mode.CLIENT, contextType, config, listenerName, false, clientSaslMechanism,
saslHandshakeRequestEnable, null, null); saslHandshakeRequestEnable, null, null, time);
} }
/** /**
@ -79,9 +80,10 @@ public class ChannelBuilders {
SecurityProtocol securityProtocol, SecurityProtocol securityProtocol,
AbstractConfig config, AbstractConfig config,
CredentialCache credentialCache, CredentialCache credentialCache,
DelegationTokenCache tokenCache) { DelegationTokenCache tokenCache,
Time time) {
return create(securityProtocol, Mode.SERVER, JaasContext.Type.SERVER, config, listenerName, return create(securityProtocol, Mode.SERVER, JaasContext.Type.SERVER, config, listenerName,
isInterBrokerListener, null, true, credentialCache, tokenCache); isInterBrokerListener, null, true, credentialCache, tokenCache, time);
} }
private static ChannelBuilder create(SecurityProtocol securityProtocol, private static ChannelBuilder create(SecurityProtocol securityProtocol,
@ -93,7 +95,8 @@ public class ChannelBuilders {
String clientSaslMechanism, String clientSaslMechanism,
boolean saslHandshakeRequestEnable, boolean saslHandshakeRequestEnable,
CredentialCache credentialCache, CredentialCache credentialCache,
DelegationTokenCache tokenCache) { DelegationTokenCache tokenCache,
Time time) {
Map<String, ?> configs; Map<String, ?> configs;
if (listenerName == null) if (listenerName == null)
configs = config.values(); configs = config.values();
@ -111,6 +114,7 @@ public class ChannelBuilders {
requireNonNullMode(mode, securityProtocol); requireNonNullMode(mode, securityProtocol);
Map<String, JaasContext> jaasContexts; Map<String, JaasContext> jaasContexts;
if (mode == Mode.SERVER) { if (mode == Mode.SERVER) {
@SuppressWarnings("unchecked")
List<String> enabledMechanisms = (List<String>) configs.get(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG); List<String> enabledMechanisms = (List<String>) configs.get(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG);
jaasContexts = new HashMap<>(enabledMechanisms.size()); jaasContexts = new HashMap<>(enabledMechanisms.size());
for (String mechanism : enabledMechanisms) for (String mechanism : enabledMechanisms)
@ -129,7 +133,8 @@ public class ChannelBuilders {
clientSaslMechanism, clientSaslMechanism,
saslHandshakeRequestEnable, saslHandshakeRequestEnable,
credentialCache, credentialCache,
tokenCache); tokenCache,
time);
break; break;
case PLAINTEXT: case PLAINTEXT:
channelBuilder = new PlaintextChannelBuilder(listenerName); channelBuilder = new PlaintextChannelBuilder(listenerName);

View File

@ -27,9 +27,45 @@ import java.net.InetAddress;
import java.net.Socket; import java.net.Socket;
import java.nio.channels.SelectionKey; import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel; import java.nio.channels.SocketChannel;
import java.util.List;
import java.util.Objects; import java.util.Objects;
import java.util.function.Supplier;
/**
* A Kafka connection either existing on a client (which could be a broker in an
* inter-broker scenario) and representing the channel to a remote broker or the
* reverse (existing on a broker and representing the channel to a remote
* client, which could be a broker in an inter-broker scenario).
* <p>
* Each instance has the following:
* <ul>
* <li>a unique ID identifying it in the {@code KafkaClient} instance via which
* the connection was made on the client-side or in the instance where it was
* accepted on the server-side</li>
* <li>a reference to the underlying {@link TransportLayer} to allow reading and
* writing</li>
* <li>an {@link Authenticator} that performs the authentication (or
* re-authentication, if that feature is enabled and it applies to this
* connection) by reading and writing directly from/to the same
* {@link TransportLayer}.</li>
* <li>a {@link MemoryPool} into which responses are read (typically the JVM
* heap for clients, though smaller pools can be used for brokers and for
* testing out-of-memory scenarios)</li>
* <li>a {@link NetworkReceive} representing the current incomplete/in-progress
* request (from the server-side perspective) or response (from the client-side
* perspective) being read, if applicable; or a non-null value that has had no
* data read into it yet or a null value if there is no in-progress
* request/response (either could be the case)</li>
* <li>a {@link Send} representing the current request (from the client-side
* perspective) or response (from the server-side perspective) that is either
* waiting to be sent or partially sent, if applicable, or null</li>
* <li>a {@link ChannelMuteState} to document if the channel has been muted due
* to memory pressure or other reasons</li>
* </ul>
*/
public class KafkaChannel { public class KafkaChannel {
private static final long MIN_REAUTH_INTERVAL_ONE_SECOND_NANOS = 1000 * 1000 * 1000;
/** /**
* Mute States for KafkaChannel: * Mute States for KafkaChannel:
* <ul> * <ul>
@ -78,7 +114,8 @@ public class KafkaChannel {
private final String id; private final String id;
private final TransportLayer transportLayer; private final TransportLayer transportLayer;
private final Authenticator authenticator; private final Supplier<Authenticator> authenticatorCreator;
private Authenticator authenticator;
// Tracks accumulated network thread time. This is updated on the network thread. // Tracks accumulated network thread time. This is updated on the network thread.
// The values are read and reset after each response is sent. // The values are read and reset after each response is sent.
private long networkThreadTimeNanos; private long networkThreadTimeNanos;
@ -92,11 +129,15 @@ public class KafkaChannel {
private ChannelMuteState muteState; private ChannelMuteState muteState;
private ChannelState state; private ChannelState state;
private SocketAddress remoteAddress; private SocketAddress remoteAddress;
private int successfulAuthentications;
private boolean midWrite;
private long lastReauthenticationStartNanos;
public KafkaChannel(String id, TransportLayer transportLayer, Authenticator authenticator, int maxReceiveSize, MemoryPool memoryPool) { public KafkaChannel(String id, TransportLayer transportLayer, Supplier<Authenticator> authenticatorCreator, int maxReceiveSize, MemoryPool memoryPool) {
this.id = id; this.id = id;
this.transportLayer = transportLayer; this.transportLayer = transportLayer;
this.authenticator = authenticator; this.authenticatorCreator = authenticatorCreator;
this.authenticator = authenticatorCreator.get();
this.networkThreadTimeNanos = 0L; this.networkThreadTimeNanos = 0L;
this.maxReceiveSize = maxReceiveSize; this.maxReceiveSize = maxReceiveSize;
this.memoryPool = memoryPool; this.memoryPool = memoryPool;
@ -142,8 +183,10 @@ public class KafkaChannel {
} }
throw e; throw e;
} }
if (ready()) if (ready()) {
++successfulAuthentications;
state = ChannelState.READY; state = ChannelState.READY;
}
} }
public void disconnect() { public void disconnect() {
@ -382,10 +425,12 @@ public class KafkaChannel {
} }
private boolean send(Send send) throws IOException { private boolean send(Send send) throws IOException {
midWrite = true;
send.writeTo(transportLayer); send.writeTo(transportLayer);
if (send.completed()) if (send.completed()) {
midWrite = false;
transportLayer.removeInterestOps(SelectionKey.OP_WRITE); transportLayer.removeInterestOps(SelectionKey.OP_WRITE);
}
return send.completed(); return send.completed();
} }
@ -412,4 +457,189 @@ public class KafkaChannel {
public int hashCode() { public int hashCode() {
return Objects.hash(id); return Objects.hash(id);
} }
@Override
public String toString() {
return super.toString() + " id=" + id;
}
/**
* Return the number of times this instance has successfully authenticated. This
* value can only exceed 1 when re-authentication is enabled and it has
* succeeded at least once.
*
* @return the number of times this instance has successfully authenticated
*/
public int successfulAuthentications() {
return successfulAuthentications;
}
/**
* If this is a server-side connection that has an expiration time and at least
* 1 second has passed since the prior re-authentication (if any) started then
* begin the process of re-authenticating the connection and return true,
* otherwise return false
*
* @param saslHandshakeNetworkReceive
* the mandatory {@link NetworkReceive} containing the
* {@code SaslHandshakeRequest} that has been received on the server
* and that initiates re-authentication.
* @param nowNanosSupplier
* {@code Supplier} of the current time. The value must be in
* nanoseconds as per {@code System.nanoTime()} and is therefore only
* useful when compared to such a value -- it's absolute value is
* meaningless.
*
* @return true if this is a server-side connection that has an expiration time
* and at least 1 second has passed since the prior re-authentication
* (if any) started to indicate that the re-authentication process has
* begun, otherwise false
* @throws AuthenticationException
* if re-authentication fails due to invalid credentials or other
* security configuration errors
* @throws IOException
* if read/write fails due to an I/O error
* @throws IllegalStateException
* if this channel is not "ready"
*/
public boolean maybeBeginServerReauthentication(NetworkReceive saslHandshakeNetworkReceive,
Supplier<Long> nowNanosSupplier) throws AuthenticationException, IOException {
if (!ready())
throw new IllegalStateException(
"KafkaChannel should be \"ready\" when processing SASL Handshake for potential re-authentication");
/*
* Re-authentication is disabled if there is no session expiration time, in
* which case the SASL handshake network receive will be processed normally,
* which results in a failure result being sent to the client. Also, no need to
* check if we are muted since since we are processing a received packet when we
* invoke this.
*/
if (authenticator.serverSessionExpirationTimeNanos() == null)
return false;
/*
* We've delayed getting the time as long as possible in case we don't need it,
* but at this point we need it -- so get it now.
*/
long nowNanos = nowNanosSupplier.get().longValue();
/*
* Cannot re-authenticate more than once every second; an attempt to do so will
* result in the SASL handshake network receive being processed normally, which
* results in a failure result being sent to the client.
*/
if (lastReauthenticationStartNanos != 0
&& nowNanos - lastReauthenticationStartNanos < MIN_REAUTH_INTERVAL_ONE_SECOND_NANOS)
return false;
lastReauthenticationStartNanos = nowNanos;
swapAuthenticatorsAndBeginReauthentication(
new ReauthenticationContext(authenticator, saslHandshakeNetworkReceive, nowNanos));
return true;
}
/**
* If this is a client-side connection that is not muted, there is no
* in-progress write, and there is a session expiration time defined that has
* past then begin the process of re-authenticating the connection and return
* true, otherwise return false
*
* @param nowNanosSupplier
* {@code Supplier} of the current time. The value must be in
* nanoseconds as per {@code System.nanoTime()} and is therefore only
* useful when compared to such a value -- it's absolute value is
* meaningless.
*
* @return true if this is a client-side connection that is not muted, there is
* no in-progress write, and there is a session expiration time defined
* that has past to indicate that the re-authentication process has
* begun, otherwise false
* @throws AuthenticationException
* if re-authentication fails due to invalid credentials or other
* security configuration errors
* @throws IOException
* if read/write fails due to an I/O error
* @throws IllegalStateException
* if this channel is not "ready"
*/
public boolean maybeBeginClientReauthentication(Supplier<Long> nowNanosSupplier)
throws AuthenticationException, IOException {
if (!ready())
throw new IllegalStateException(
"KafkaChannel should always be \"ready\" when it is checked for possible re-authentication");
if (muteState != ChannelMuteState.NOT_MUTED || midWrite
|| authenticator.clientSessionReauthenticationTimeNanos() == null)
return false;
/*
* We've delayed getting the time as long as possible in case we don't need it,
* but at this point we need it -- so get it now.
*/
long nowNanos = nowNanosSupplier.get().longValue();
if (nowNanos < authenticator.clientSessionReauthenticationTimeNanos().longValue())
return false;
swapAuthenticatorsAndBeginReauthentication(new ReauthenticationContext(authenticator, receive, nowNanos));
receive = null;
return true;
}
/**
* Return the number of milliseconds that elapsed while re-authenticating this
* session from the perspective of this instance, if applicable, otherwise null.
* The server-side perspective will yield a lower value than the client-side
* perspective of the same re-authentication because the client-side observes an
* additional network round-trip.
*
* @return the number of milliseconds that elapsed while re-authenticating this
* session from the perspective of this instance, if applicable,
* otherwise null
*/
public Long reauthenticationLatencyMs() {
return authenticator.reauthenticationLatencyMs();
}
/**
* Return true if this is a server-side channel and the given time is past the
* session expiration time, if any, otherwise false
*
* @param nowNanos
* the current time in nanoseconds as per {@code System.nanoTime()}
* @return true if this is a server-side channel and the given time is past the
* session expiration time, if any, otherwise false
*/
public boolean serverAuthenticationSessionExpired(long nowNanos) {
Long serverSessionExpirationTimeNanos = authenticator.serverSessionExpirationTimeNanos();
return serverSessionExpirationTimeNanos != null && nowNanos - serverSessionExpirationTimeNanos.longValue() > 0;
}
/**
* Return the (always non-null but possibly empty) client-side
* {@link NetworkReceive} responses that arrived during re-authentication that
* are unrelated to re-authentication, if any. These correspond to requests sent
* prior to the beginning of re-authentication; the requests were made when the
* channel was successfully authenticated, and the responses arrived during the
* re-authentication process.
*
* @return the (always non-null but possibly empty) client-side
* {@link NetworkReceive} responses that arrived during
* re-authentication that are unrelated to re-authentication, if any
*/
public List<NetworkReceive> getAndClearResponsesReceivedDuringReauthentication() {
return authenticator.getAndClearResponsesReceivedDuringReauthentication();
}
/**
* Return true if this is a server-side channel and the connected client has
* indicated that it supports re-authentication, otherwise false
*
* @return true if this is a server-side channel and the connected client has
* indicated that it supports re-authentication, otherwise false
*/
boolean connectedClientSupportsReauthentication() {
return authenticator.connectedClientSupportsReauthentication();
}
private void swapAuthenticatorsAndBeginReauthentication(ReauthenticationContext reauthenticationContext)
throws IOException {
// it is up to the new authenticator to close the old one
// replace with a new one and begin the process of re-authenticating
authenticator = authenticatorCreator.get();
authenticator.reauthenticate(reauthenticationContext);
}
} }

View File

@ -29,6 +29,7 @@ import java.io.Closeable;
import java.net.InetAddress; import java.net.InetAddress;
import java.nio.channels.SelectionKey; import java.nio.channels.SelectionKey;
import java.util.Map; import java.util.Map;
import java.util.function.Supplier;
public class PlaintextChannelBuilder implements ChannelBuilder { public class PlaintextChannelBuilder implements ChannelBuilder {
private static final Logger log = LoggerFactory.getLogger(PlaintextChannelBuilder.class); private static final Logger log = LoggerFactory.getLogger(PlaintextChannelBuilder.class);
@ -51,8 +52,8 @@ public class PlaintextChannelBuilder implements ChannelBuilder {
public KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize, MemoryPool memoryPool) throws KafkaException { public KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize, MemoryPool memoryPool) throws KafkaException {
try { try {
PlaintextTransportLayer transportLayer = new PlaintextTransportLayer(key); PlaintextTransportLayer transportLayer = new PlaintextTransportLayer(key);
PlaintextAuthenticator authenticator = new PlaintextAuthenticator(configs, transportLayer, listenerName); Supplier<Authenticator> authenticatorCreator = () -> new PlaintextAuthenticator(configs, transportLayer, listenerName);
return new KafkaChannel(id, transportLayer, authenticator, maxReceiveSize, return new KafkaChannel(id, transportLayer, authenticatorCreator, maxReceiveSize,
memoryPool != null ? memoryPool : MemoryPool.NONE); memoryPool != null ? memoryPool : MemoryPool.NONE);
} catch (Exception e) { } catch (Exception e) {
log.warn("Failed to create channel due to ", e); log.warn("Failed to create channel due to ", e);

View File

@ -0,0 +1,94 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.network;
import java.util.Objects;
/**
* Defines the context in which an {@link Authenticator} is to be created during
* a re-authentication.
*/
public class ReauthenticationContext {
private final NetworkReceive networkReceive;
private final Authenticator previousAuthenticator;
private final long reauthenticationBeginNanos;
/**
* Constructor
*
* @param previousAuthenticator
* the mandatory {@link Authenticator} that was previously used to
* authenticate the channel
* @param networkReceive
* the applicable {@link NetworkReceive} instance, if any. For the
* client side this may be a response that has been partially read, a
* non-null instance that has had no data read into it yet, or null;
* if it is non-null then this is the instance that data should
* initially be read into during re-authentication. For the server
* side this is mandatory and it must contain the
* {@code SaslHandshakeRequest} that has been received on the server
* and that initiates re-authentication.
*
* @param nowNanos
* the current time. The value is in nanoseconds as per
* {@code System.nanoTime()} and is therefore only useful when
* compared to such a value -- it's absolute value is meaningless.
* This defines the moment when re-authentication begins.
*/
public ReauthenticationContext(Authenticator previousAuthenticator, NetworkReceive networkReceive, long nowNanos) {
this.previousAuthenticator = Objects.requireNonNull(previousAuthenticator);
this.networkReceive = networkReceive;
this.reauthenticationBeginNanos = nowNanos;
}
/**
* Return the applicable {@link NetworkReceive} instance, if any. For the client
* side this may be a response that has been partially read, a non-null instance
* that has had no data read into it yet, or null; if it is non-null then this
* is the instance that data should initially be read into during
* re-authentication. For the server side this is mandatory and it must contain
* the {@code SaslHandshakeRequest} that has been received on the server and
* that initiates re-authentication.
*
* @return the applicable {@link NetworkReceive} instance, if any
*/
public NetworkReceive networkReceive() {
return networkReceive;
}
/**
* Return the always non-null {@link Authenticator} that was previously used to
* authenticate the channel
*
* @return the always non-null {@link Authenticator} that was previously used to
* authenticate the channel
*/
public Authenticator previousAuthenticator() {
return previousAuthenticator;
}
/**
* Return the time when re-authentication began. The value is in nanoseconds as
* per {@code System.nanoTime()} and is therefore only useful when compared to
* such a value -- it's absolute value is meaningless.
*
* @return the time when re-authentication began
*/
public long reauthenticationBeginNanos() {
return reauthenticationBeginNanos;
}
}

View File

@ -47,6 +47,7 @@ import org.apache.kafka.common.security.scram.internals.ScramServerCallbackHandl
import org.apache.kafka.common.security.ssl.SslFactory; import org.apache.kafka.common.security.ssl.SslFactory;
import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache; import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache;
import org.apache.kafka.common.utils.Java; import org.apache.kafka.common.utils.Java;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Utils; import org.apache.kafka.common.utils.Utils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -62,6 +63,7 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.function.Supplier;
import javax.security.auth.Subject; import javax.security.auth.Subject;
@ -84,6 +86,8 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
private Map<String, ?> configs; private Map<String, ?> configs;
private KerberosShortNamer kerberosShortNamer; private KerberosShortNamer kerberosShortNamer;
private Map<String, AuthenticateCallbackHandler> saslCallbackHandlers; private Map<String, AuthenticateCallbackHandler> saslCallbackHandlers;
private Map<String, Long> connectionsMaxReauthMsByMechanism;
private final Time time;
public SaslChannelBuilder(Mode mode, public SaslChannelBuilder(Mode mode,
Map<String, JaasContext> jaasContexts, Map<String, JaasContext> jaasContexts,
@ -93,7 +97,8 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
String clientSaslMechanism, String clientSaslMechanism,
boolean handshakeRequestEnable, boolean handshakeRequestEnable,
CredentialCache credentialCache, CredentialCache credentialCache,
DelegationTokenCache tokenCache) { DelegationTokenCache tokenCache,
Time time) {
this.mode = mode; this.mode = mode;
this.jaasContexts = jaasContexts; this.jaasContexts = jaasContexts;
this.loginManagers = new HashMap<>(jaasContexts.size()); this.loginManagers = new HashMap<>(jaasContexts.size());
@ -106,6 +111,8 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
this.credentialCache = credentialCache; this.credentialCache = credentialCache;
this.tokenCache = tokenCache; this.tokenCache = tokenCache;
this.saslCallbackHandlers = new HashMap<>(); this.saslCallbackHandlers = new HashMap<>();
this.connectionsMaxReauthMsByMechanism = new HashMap<>();
this.time = time;
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -113,9 +120,10 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
public void configure(Map<String, ?> configs) throws KafkaException { public void configure(Map<String, ?> configs) throws KafkaException {
try { try {
this.configs = configs; this.configs = configs;
if (mode == Mode.SERVER) if (mode == Mode.SERVER) {
createServerCallbackHandlers(configs); createServerCallbackHandlers(configs);
else createConnectionsMaxReauthMsMap(configs);
} else
createClientCallbackHandler(configs); createClientCallbackHandler(configs);
for (Map.Entry<String, AuthenticateCallbackHandler> entry : saslCallbackHandlers.entrySet()) { for (Map.Entry<String, AuthenticateCallbackHandler> entry : saslCallbackHandlers.entrySet()) {
String mechanism = entry.getKey(); String mechanism = entry.getKey();
@ -130,7 +138,6 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
} catch (Exception ke) { } catch (Exception ke) {
defaultRealm = ""; defaultRealm = "";
} }
@SuppressWarnings("unchecked")
List<String> principalToLocalRules = (List<String>) configs.get(BrokerSecurityConfigs.SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES_CONFIG); List<String> principalToLocalRules = (List<String>) configs.get(BrokerSecurityConfigs.SASL_KERBEROS_PRINCIPAL_TO_LOCAL_RULES_CONFIG);
if (principalToLocalRules != null) if (principalToLocalRules != null)
kerberosShortNamer = KerberosShortNamer.fromUnparsedRules(defaultRealm, principalToLocalRules); kerberosShortNamer = KerberosShortNamer.fromUnparsedRules(defaultRealm, principalToLocalRules);
@ -182,16 +189,17 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
SocketChannel socketChannel = (SocketChannel) key.channel(); SocketChannel socketChannel = (SocketChannel) key.channel();
Socket socket = socketChannel.socket(); Socket socket = socketChannel.socket();
TransportLayer transportLayer = buildTransportLayer(id, key, socketChannel); TransportLayer transportLayer = buildTransportLayer(id, key, socketChannel);
Authenticator authenticator; Supplier<Authenticator> authenticatorCreator;
if (mode == Mode.SERVER) { if (mode == Mode.SERVER) {
authenticator = buildServerAuthenticator(configs, authenticatorCreator = () -> buildServerAuthenticator(configs,
saslCallbackHandlers, Collections.unmodifiableMap(saslCallbackHandlers),
id, id,
transportLayer, transportLayer,
subjects); Collections.unmodifiableMap(subjects),
Collections.unmodifiableMap(connectionsMaxReauthMsByMechanism));
} else { } else {
LoginManager loginManager = loginManagers.get(clientSaslMechanism); LoginManager loginManager = loginManagers.get(clientSaslMechanism);
authenticator = buildClientAuthenticator(configs, authenticatorCreator = () -> buildClientAuthenticator(configs,
saslCallbackHandlers.get(clientSaslMechanism), saslCallbackHandlers.get(clientSaslMechanism),
id, id,
socket.getInetAddress().getHostName(), socket.getInetAddress().getHostName(),
@ -199,7 +207,7 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
transportLayer, transportLayer,
subjects.get(clientSaslMechanism)); subjects.get(clientSaslMechanism));
} }
return new KafkaChannel(id, transportLayer, authenticator, maxReceiveSize, memoryPool != null ? memoryPool : MemoryPool.NONE); return new KafkaChannel(id, transportLayer, authenticatorCreator, maxReceiveSize, memoryPool != null ? memoryPool : MemoryPool.NONE);
} catch (Exception e) { } catch (Exception e) {
log.info("Failed to create channel due to ", e); log.info("Failed to create channel due to ", e);
throw new KafkaException(e); throw new KafkaException(e);
@ -215,7 +223,8 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
handler.close(); handler.close();
} }
private TransportLayer buildTransportLayer(String id, SelectionKey key, SocketChannel socketChannel) throws IOException { // Visible to override for testing
protected TransportLayer buildTransportLayer(String id, SelectionKey key, SocketChannel socketChannel) throws IOException {
if (this.securityProtocol == SecurityProtocol.SASL_SSL) { if (this.securityProtocol == SecurityProtocol.SASL_SSL) {
return SslTransportLayer.create(id, key, return SslTransportLayer.create(id, key,
sslFactory.createSslEngine(socketChannel.socket().getInetAddress().getHostName(), socketChannel.socket().getPort())); sslFactory.createSslEngine(socketChannel.socket().getInetAddress().getHostName(), socketChannel.socket().getPort()));
@ -229,9 +238,10 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
Map<String, AuthenticateCallbackHandler> callbackHandlers, Map<String, AuthenticateCallbackHandler> callbackHandlers,
String id, String id,
TransportLayer transportLayer, TransportLayer transportLayer,
Map<String, Subject> subjects) throws IOException { Map<String, Subject> subjects,
Map<String, Long> connectionsMaxReauthMsByMechanism) {
return new SaslServerAuthenticator(configs, callbackHandlers, id, subjects, return new SaslServerAuthenticator(configs, callbackHandlers, id, subjects,
kerberosShortNamer, listenerName, securityProtocol, transportLayer); kerberosShortNamer, listenerName, securityProtocol, transportLayer, connectionsMaxReauthMsByMechanism, time);
} }
// Visible to override for testing // Visible to override for testing
@ -240,9 +250,9 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
String id, String id,
String serverHost, String serverHost,
String servicePrincipal, String servicePrincipal,
TransportLayer transportLayer, Subject subject) throws IOException { TransportLayer transportLayer, Subject subject) {
return new SaslClientAuthenticator(configs, callbackHandler, id, subject, servicePrincipal, return new SaslClientAuthenticator(configs, callbackHandler, id, subject, servicePrincipal,
serverHost, clientSaslMechanism, handshakeRequestEnable, transportLayer); serverHost, clientSaslMechanism, handshakeRequestEnable, transportLayer, time);
} }
// Package private for testing // Package private for testing
@ -272,6 +282,7 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
} }
private void createClientCallbackHandler(Map<String, ?> configs) { private void createClientCallbackHandler(Map<String, ?> configs) {
@SuppressWarnings("unchecked")
Class<? extends AuthenticateCallbackHandler> clazz = (Class<? extends AuthenticateCallbackHandler>) configs.get(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS); Class<? extends AuthenticateCallbackHandler> clazz = (Class<? extends AuthenticateCallbackHandler>) configs.get(SaslConfigs.SASL_CLIENT_CALLBACK_HANDLER_CLASS);
if (clazz == null) if (clazz == null)
clazz = clientCallbackHandlerClass(); clazz = clientCallbackHandlerClass();
@ -283,6 +294,7 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
for (String mechanism : jaasContexts.keySet()) { for (String mechanism : jaasContexts.keySet()) {
AuthenticateCallbackHandler callbackHandler; AuthenticateCallbackHandler callbackHandler;
String prefix = ListenerName.saslMechanismPrefix(mechanism); String prefix = ListenerName.saslMechanismPrefix(mechanism);
@SuppressWarnings("unchecked")
Class<? extends AuthenticateCallbackHandler> clazz = Class<? extends AuthenticateCallbackHandler> clazz =
(Class<? extends AuthenticateCallbackHandler>) configs.get(prefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS); (Class<? extends AuthenticateCallbackHandler>) configs.get(prefix + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS);
if (clazz != null) if (clazz != null)
@ -299,6 +311,17 @@ public class SaslChannelBuilder implements ChannelBuilder, ListenerReconfigurabl
} }
} }
private void createConnectionsMaxReauthMsMap(Map<String, ?> configs) {
for (String mechanism : jaasContexts.keySet()) {
String prefix = ListenerName.saslMechanismPrefix(mechanism);
Long connectionsMaxReauthMs = (Long) configs.get(prefix + BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS);
if (connectionsMaxReauthMs == null)
connectionsMaxReauthMs = (Long) configs.get(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS);
if (connectionsMaxReauthMs != null)
connectionsMaxReauthMsByMechanism.put(mechanism, connectionsMaxReauthMs);
}
}
private Class<? extends Login> defaultLoginClass(Map<String, ?> configs) { private Class<? extends Login> defaultLoginClass(Map<String, ?> configs) {
if (jaasContexts.containsKey(SaslConfigs.GSSAPI_MECHANISM)) if (jaasContexts.containsKey(SaslConfigs.GSSAPI_MECHANISM))
return KerberosLogin.class; return KerberosLogin.class;

View File

@ -27,6 +27,7 @@ import org.apache.kafka.common.metrics.stats.Count;
import org.apache.kafka.common.metrics.stats.Max; import org.apache.kafka.common.metrics.stats.Max;
import org.apache.kafka.common.metrics.stats.Meter; import org.apache.kafka.common.metrics.stats.Meter;
import org.apache.kafka.common.metrics.stats.SampledStat; import org.apache.kafka.common.metrics.stats.SampledStat;
import org.apache.kafka.common.metrics.stats.Total;
import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Time;
import org.slf4j.Logger; import org.slf4j.Logger;
@ -531,11 +532,31 @@ public class Selector implements Selectable, AutoCloseable {
try { try {
channel.prepare(); channel.prepare();
} catch (AuthenticationException e) { } catch (AuthenticationException e) {
sensors.failedAuthentication.record(); if (channel.successfulAuthentications() == 0)
sensors.failedAuthentication.record();
else
sensors.failedReauthentication.record();
throw e; throw e;
} }
if (channel.ready()) if (channel.ready()) {
sensors.successfulAuthentication.record(); long readyTimeMs = time.milliseconds();
if (channel.successfulAuthentications() == 1) {
sensors.successfulAuthentication.record(1.0, readyTimeMs);
if (!channel.connectedClientSupportsReauthentication())
sensors.successfulAuthenticationNoReauth.record(1.0, readyTimeMs);
} else {
sensors.successfulReauthentication.record(1.0, readyTimeMs);
if (channel.reauthenticationLatencyMs() == null)
log.warn(
"Should never happen: re-authentication latency for a re-authenticated channel was null; continuing...");
else
sensors.reauthenticationLatency
.record(channel.reauthenticationLatencyMs().doubleValue(), readyTimeMs);
}
}
List<NetworkReceive> responsesReceivedDuringReauthentication = channel
.getAndClearResponsesReceivedDuringReauthentication();
responsesReceivedDuringReauthentication.forEach(receive -> addToStagedReceives(channel, receive));
} }
attemptRead(key, channel); attemptRead(key, channel);
@ -551,7 +572,8 @@ public class Selector implements Selectable, AutoCloseable {
} }
/* if channel is ready write to any sockets that have space in their buffer and for which we have data */ /* if channel is ready write to any sockets that have space in their buffer and for which we have data */
if (channel.ready() && key.isWritable()) { if (channel.ready() && key.isWritable() && !channel.maybeBeginClientReauthentication(
() -> channelStartTimeNanos != 0 ? channelStartTimeNanos : currentTimeNanos)) {
Send send; Send send;
try { try {
send = channel.write(); send = channel.write();
@ -970,7 +992,11 @@ public class Selector implements Selectable, AutoCloseable {
public final Sensor connectionClosed; public final Sensor connectionClosed;
public final Sensor connectionCreated; public final Sensor connectionCreated;
public final Sensor successfulAuthentication; public final Sensor successfulAuthentication;
public final Sensor successfulReauthentication;
public final Sensor successfulAuthenticationNoReauth;
public final Sensor reauthenticationLatency;
public final Sensor failedAuthentication; public final Sensor failedAuthentication;
public final Sensor failedReauthentication;
public final Sensor bytesTransferred; public final Sensor bytesTransferred;
public final Sensor bytesSent; public final Sensor bytesSent;
public final Sensor bytesReceived; public final Sensor bytesReceived;
@ -1007,10 +1033,35 @@ public class Selector implements Selectable, AutoCloseable {
this.successfulAuthentication.add(createMeter(metrics, metricGrpName, metricTags, this.successfulAuthentication.add(createMeter(metrics, metricGrpName, metricTags,
"successful-authentication", "connections with successful authentication")); "successful-authentication", "connections with successful authentication"));
this.successfulReauthentication = sensor("successful-reauthentication:" + tagsSuffix);
this.successfulReauthentication.add(createMeter(metrics, metricGrpName, metricTags,
"successful-reauthentication", "successful re-authentication of connections"));
this.successfulAuthenticationNoReauth = sensor("successful-authentication-no-reauth:" + tagsSuffix);
MetricName successfulAuthenticationNoReauthMetricName = metrics.metricName(
"successful-authentication-no-reauth-total", metricGrpName,
"The total number of connections with successful authentication where the client does not support re-authentication",
metricTags);
this.successfulAuthenticationNoReauth.add(successfulAuthenticationNoReauthMetricName, new Total());
this.failedAuthentication = sensor("failed-authentication:" + tagsSuffix); this.failedAuthentication = sensor("failed-authentication:" + tagsSuffix);
this.failedAuthentication.add(createMeter(metrics, metricGrpName, metricTags, this.failedAuthentication.add(createMeter(metrics, metricGrpName, metricTags,
"failed-authentication", "connections with failed authentication")); "failed-authentication", "connections with failed authentication"));
this.failedReauthentication = sensor("failed-reauthentication:" + tagsSuffix);
this.failedReauthentication.add(createMeter(metrics, metricGrpName, metricTags,
"failed-reauthentication", "failed re-authentication of connections"));
this.reauthenticationLatency = sensor("reauthentication-latency:" + tagsSuffix);
MetricName reauthenticationLatencyMaxMetricName = metrics.metricName("reauthentication-latency-max",
metricGrpName, "The max latency observed due to re-authentication",
metricTags);
this.reauthenticationLatency.add(reauthenticationLatencyMaxMetricName, new Max());
MetricName reauthenticationLatencyAvgMetricName = metrics.metricName("reauthentication-latency-avg",
metricGrpName, "The average latency observed due to re-authentication",
metricTags);
this.reauthenticationLatency.add(reauthenticationLatencyAvgMetricName, new Avg());
this.bytesTransferred = sensor("bytes-sent-received:" + tagsSuffix); this.bytesTransferred = sensor("bytes-sent-received:" + tagsSuffix);
bytesTransferred.add(createMeter(metrics, metricGrpName, metricTags, new Count(), bytesTransferred.add(createMeter(metrics, metricGrpName, metricTags, new Count(),
"network-io", "network operations (reads or writes) on all connections")); "network-io", "network operations (reads or writes) on all connections"));

View File

@ -38,6 +38,7 @@ import java.nio.channels.SocketChannel;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.function.Supplier;
public class SslChannelBuilder implements ChannelBuilder, ListenerReconfigurable { public class SslChannelBuilder implements ChannelBuilder, ListenerReconfigurable {
private static final Logger log = LoggerFactory.getLogger(SslChannelBuilder.class); private static final Logger log = LoggerFactory.getLogger(SslChannelBuilder.class);
@ -97,8 +98,8 @@ public class SslChannelBuilder implements ChannelBuilder, ListenerReconfigurable
public KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize, MemoryPool memoryPool) throws KafkaException { public KafkaChannel buildChannel(String id, SelectionKey key, int maxReceiveSize, MemoryPool memoryPool) throws KafkaException {
try { try {
SslTransportLayer transportLayer = buildTransportLayer(sslFactory, id, key, peerHost(key)); SslTransportLayer transportLayer = buildTransportLayer(sslFactory, id, key, peerHost(key));
Authenticator authenticator = new SslAuthenticator(configs, transportLayer, listenerName, sslPrincipalMapper); Supplier<Authenticator> authenticatorCreator = () -> new SslAuthenticator(configs, transportLayer, listenerName, sslPrincipalMapper);
return new KafkaChannel(id, transportLayer, authenticator, maxReceiveSize, return new KafkaChannel(id, transportLayer, authenticatorCreator, maxReceiveSize,
memoryPool != null ? memoryPool : MemoryPool.NONE); memoryPool != null ? memoryPool : MemoryPool.NONE);
} catch (Exception e) { } catch (Exception e) {
log.info("Failed to create channel due to ", e); log.info("Failed to create channel due to ", e);

View File

@ -40,8 +40,11 @@ public class SaslAuthenticateRequest extends AbstractRequest {
private static final Schema SASL_AUTHENTICATE_REQUEST_V0 = new Schema( private static final Schema SASL_AUTHENTICATE_REQUEST_V0 = new Schema(
new Field(SASL_AUTH_BYTES_KEY_NAME, BYTES, "SASL authentication bytes from client as defined by the SASL mechanism.")); new Field(SASL_AUTH_BYTES_KEY_NAME, BYTES, "SASL authentication bytes from client as defined by the SASL mechanism."));
/* v1 request is the same as v0; session_lifetime_ms has been added to the response */
private static final Schema SASL_AUTHENTICATE_REQUEST_V1 = SASL_AUTHENTICATE_REQUEST_V0;
public static Schema[] schemaVersions() { public static Schema[] schemaVersions() {
return new Schema[]{SASL_AUTHENTICATE_REQUEST_V0}; return new Schema[]{SASL_AUTHENTICATE_REQUEST_V0, SASL_AUTHENTICATE_REQUEST_V1};
} }
private final ByteBuffer saslAuthBytes; private final ByteBuffer saslAuthBytes;
@ -90,6 +93,7 @@ public class SaslAuthenticateRequest extends AbstractRequest {
short versionId = version(); short versionId = version();
switch (versionId) { switch (versionId) {
case 0: case 0:
case 1:
return new SaslAuthenticateResponse(Errors.forException(e), e.getMessage()); return new SaslAuthenticateResponse(Errors.forException(e), e.getMessage());
default: default:
throw new IllegalArgumentException(String.format("Version %d is not valid. Valid versions for %s are 0 to %d", throw new IllegalArgumentException(String.format("Version %d is not valid. Valid versions for %s are 0 to %d",

View File

@ -28,6 +28,7 @@ import java.util.Map;
import static org.apache.kafka.common.protocol.CommonFields.ERROR_CODE; import static org.apache.kafka.common.protocol.CommonFields.ERROR_CODE;
import static org.apache.kafka.common.protocol.CommonFields.ERROR_MESSAGE; import static org.apache.kafka.common.protocol.CommonFields.ERROR_MESSAGE;
import static org.apache.kafka.common.protocol.types.Type.BYTES; import static org.apache.kafka.common.protocol.types.Type.BYTES;
import static org.apache.kafka.common.protocol.types.Type.INT64;
/** /**
@ -36,14 +37,21 @@ import static org.apache.kafka.common.protocol.types.Type.BYTES;
*/ */
public class SaslAuthenticateResponse extends AbstractResponse { public class SaslAuthenticateResponse extends AbstractResponse {
private static final String SASL_AUTH_BYTES_KEY_NAME = "sasl_auth_bytes"; private static final String SASL_AUTH_BYTES_KEY_NAME = "sasl_auth_bytes";
private static final String SESSION_LIFETIME_MS = "session_lifetime_ms";
private static final Schema SASL_AUTHENTICATE_RESPONSE_V0 = new Schema( private static final Schema SASL_AUTHENTICATE_RESPONSE_V0 = new Schema(
ERROR_CODE, ERROR_CODE,
ERROR_MESSAGE, ERROR_MESSAGE,
new Field(SASL_AUTH_BYTES_KEY_NAME, BYTES, "SASL authentication bytes from server as defined by the SASL mechanism.")); new Field(SASL_AUTH_BYTES_KEY_NAME, BYTES, "SASL authentication bytes from server as defined by the SASL mechanism."));
private static final Schema SASL_AUTHENTICATE_RESPONSE_V1 = new Schema(
ERROR_CODE,
ERROR_MESSAGE,
new Field(SASL_AUTH_BYTES_KEY_NAME, BYTES, "SASL authentication bytes from server as defined by the SASL mechanism."),
new Field(SESSION_LIFETIME_MS, INT64, "Number of milliseconds after which only re-authentication over the existing connection to create a new session can occur."));
public static Schema[] schemaVersions() { public static Schema[] schemaVersions() {
return new Schema[]{SASL_AUTHENTICATE_RESPONSE_V0}; return new Schema[]{SASL_AUTHENTICATE_RESPONSE_V0, SASL_AUTHENTICATE_RESPONSE_V1};
} }
private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0); private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0);
@ -56,21 +64,28 @@ public class SaslAuthenticateResponse extends AbstractResponse {
*/ */
private final Errors error; private final Errors error;
private final String errorMessage; private final String errorMessage;
private final long sessionLifetimeMs;
public SaslAuthenticateResponse(Errors error, String errorMessage) { public SaslAuthenticateResponse(Errors error, String errorMessage) {
this(error, errorMessage, EMPTY_BUFFER); this(error, errorMessage, EMPTY_BUFFER);
} }
public SaslAuthenticateResponse(Errors error, String errorMessage, ByteBuffer saslAuthBytes) { public SaslAuthenticateResponse(Errors error, String errorMessage, ByteBuffer saslAuthBytes) {
this(error, errorMessage, saslAuthBytes, 0L);
}
public SaslAuthenticateResponse(Errors error, String errorMessage, ByteBuffer saslAuthBytes, long sessionLifetimeMs) {
this.error = error; this.error = error;
this.errorMessage = errorMessage; this.errorMessage = errorMessage;
this.saslAuthBytes = saslAuthBytes; this.saslAuthBytes = saslAuthBytes;
this.sessionLifetimeMs = sessionLifetimeMs;
} }
public SaslAuthenticateResponse(Struct struct) { public SaslAuthenticateResponse(Struct struct) {
error = Errors.forCode(struct.get(ERROR_CODE)); error = Errors.forCode(struct.get(ERROR_CODE));
errorMessage = struct.get(ERROR_MESSAGE); errorMessage = struct.get(ERROR_MESSAGE);
saslAuthBytes = struct.getBytes(SASL_AUTH_BYTES_KEY_NAME); saslAuthBytes = struct.getBytes(SASL_AUTH_BYTES_KEY_NAME);
sessionLifetimeMs = struct.hasField(SESSION_LIFETIME_MS) ? struct.getLong(SESSION_LIFETIME_MS).longValue() : 0L;
} }
public Errors error() { public Errors error() {
@ -85,6 +100,10 @@ public class SaslAuthenticateResponse extends AbstractResponse {
return saslAuthBytes; return saslAuthBytes;
} }
public long sessionLifetimeMs() {
return sessionLifetimeMs;
}
@Override @Override
public Map<Errors, Integer> errorCounts() { public Map<Errors, Integer> errorCounts() {
return errorCounts(error); return errorCounts(error);
@ -96,6 +115,8 @@ public class SaslAuthenticateResponse extends AbstractResponse {
struct.set(ERROR_CODE, error.code()); struct.set(ERROR_CODE, error.code());
struct.set(ERROR_MESSAGE, errorMessage); struct.set(ERROR_MESSAGE, errorMessage);
struct.set(SASL_AUTH_BYTES_KEY_NAME, saslAuthBytes); struct.set(SASL_AUTH_BYTES_KEY_NAME, saslAuthBytes);
if (version > 0)
struct.set(SESSION_LIFETIME_MS, sessionLifetimeMs);
return struct; return struct;
} }

View File

@ -25,6 +25,7 @@ import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.errors.UnsupportedSaslMechanismException; import org.apache.kafka.common.errors.UnsupportedSaslMechanismException;
import org.apache.kafka.common.network.Authenticator; import org.apache.kafka.common.network.Authenticator;
import org.apache.kafka.common.network.NetworkSend; import org.apache.kafka.common.network.NetworkSend;
import org.apache.kafka.common.network.ReauthenticationContext;
import org.apache.kafka.common.network.NetworkReceive; import org.apache.kafka.common.network.NetworkReceive;
import org.apache.kafka.common.network.Send; import org.apache.kafka.common.network.Send;
import org.apache.kafka.common.network.TransportLayer; import org.apache.kafka.common.network.TransportLayer;
@ -43,6 +44,7 @@ import org.apache.kafka.common.requests.SaslHandshakeResponse;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.auth.KafkaPrincipal; import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.kerberos.KerberosError; import org.apache.kafka.common.security.kerberos.KerberosError;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Utils; import org.apache.kafka.common.utils.Utils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -57,27 +59,49 @@ import java.nio.channels.SelectionKey;
import java.security.Principal; import java.security.Principal;
import java.security.PrivilegedActionException; import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction; import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator; import java.util.Iterator;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Random;
import java.util.Set; import java.util.Set;
public class SaslClientAuthenticator implements Authenticator { public class SaslClientAuthenticator implements Authenticator {
/**
* The internal state transitions for initial authentication of a channel are
* declared in order, starting with {@link #SEND_APIVERSIONS_REQUEST} and ending
* in either {@link #COMPLETE} or {@link #FAILED}.
* <p>
* Re-authentication of a channel starts with the state
* {@link #REAUTH_PROCESS_ORIG_APIVERSIONS_RESPONSE} and then flows to
* {@link #REAUTH_SEND_HANDSHAKE_REQUEST} followed by
* {@link #REAUTH_RECEIVE_HANDSHAKE_OR_OTHER_RESPONSE} and then
* {@value #REAUTH_INITIAL}; after that the flow joins the authentication flow
* at the {@link #INTERMEDIATE} state and ends at either {@link #COMPLETE} or
* {@link #FAILED}.
*/
public enum SaslState { public enum SaslState {
SEND_APIVERSIONS_REQUEST, // Initial state: client sends ApiVersionsRequest in this state SEND_APIVERSIONS_REQUEST, // Initial state for authentication: client sends ApiVersionsRequest in this state when authenticating
RECEIVE_APIVERSIONS_RESPONSE, // Awaiting ApiVersionsResponse from server RECEIVE_APIVERSIONS_RESPONSE, // Awaiting ApiVersionsResponse from server
SEND_HANDSHAKE_REQUEST, // Received ApiVersionsResponse, send SaslHandshake request SEND_HANDSHAKE_REQUEST, // Received ApiVersionsResponse, send SaslHandshake request
RECEIVE_HANDSHAKE_RESPONSE, // Awaiting SaslHandshake request from server RECEIVE_HANDSHAKE_RESPONSE, // Awaiting SaslHandshake response from server when authenticating
INITIAL, // Initial state starting SASL token exchange for configured mechanism, send first token INITIAL, // Initial authentication state starting SASL token exchange for configured mechanism, send first token
INTERMEDIATE, // Intermediate state during SASL token exchange, process challenges and send responses INTERMEDIATE, // Intermediate state during SASL token exchange, process challenges and send responses
CLIENT_COMPLETE, // Sent response to last challenge. If using SaslAuthenticate, wait for authentication status from server, else COMPLETE CLIENT_COMPLETE, // Sent response to last challenge. If using SaslAuthenticate, wait for authentication status from server, else COMPLETE
COMPLETE, // Authentication sequence complete. If using SaslAuthenticate, this state implies successful authentication. COMPLETE, // Authentication sequence complete. If using SaslAuthenticate, this state implies successful authentication.
FAILED // Failed authentication due to an error at some stage FAILED, // Failed authentication due to an error at some stage
REAUTH_PROCESS_ORIG_APIVERSIONS_RESPONSE, // Initial state for re-authentication: process ApiVersionsResponse from original authentication
REAUTH_SEND_HANDSHAKE_REQUEST, // Processed original ApiVersionsResponse, send SaslHandshake request as part of re-authentication
REAUTH_RECEIVE_HANDSHAKE_OR_OTHER_RESPONSE, // Awaiting SaslHandshake response from server when re-authenticating, and may receive other, in-flight responses sent prior to start of re-authentication as well
REAUTH_INITIAL, // Initial re-authentication state starting SASL token exchange for configured mechanism, send first token
} }
private static final Logger LOG = LoggerFactory.getLogger(SaslClientAuthenticator.class); private static final Logger LOG = LoggerFactory.getLogger(SaslClientAuthenticator.class);
private static final short DISABLE_KAFKA_SASL_AUTHENTICATE_HEADER = -1; private static final short DISABLE_KAFKA_SASL_AUTHENTICATE_HEADER = -1;
private static final Random RNG = new Random();
private final Subject subject; private final Subject subject;
private final String servicePrincipal; private final String servicePrincipal;
@ -89,6 +113,8 @@ public class SaslClientAuthenticator implements Authenticator {
private final Map<String, ?> configs; private final Map<String, ?> configs;
private final String clientPrincipalName; private final String clientPrincipalName;
private final AuthenticateCallbackHandler callbackHandler; private final AuthenticateCallbackHandler callbackHandler;
private final Time time;
private final ReauthInfo reauthInfo;
// buffers used in `authenticate` // buffers used in `authenticate`
private NetworkReceive netInBuffer; private NetworkReceive netInBuffer;
@ -113,7 +139,8 @@ public class SaslClientAuthenticator implements Authenticator {
String host, String host,
String mechanism, String mechanism,
boolean handshakeRequestEnable, boolean handshakeRequestEnable,
TransportLayer transportLayer) { TransportLayer transportLayer,
Time time) {
this.node = node; this.node = node;
this.subject = subject; this.subject = subject;
this.callbackHandler = callbackHandler; this.callbackHandler = callbackHandler;
@ -124,6 +151,8 @@ public class SaslClientAuthenticator implements Authenticator {
this.transportLayer = transportLayer; this.transportLayer = transportLayer;
this.configs = configs; this.configs = configs;
this.saslAuthenticateVersion = DISABLE_KAFKA_SASL_AUTHENTICATE_HEADER; this.saslAuthenticateVersion = DISABLE_KAFKA_SASL_AUTHENTICATE_HEADER;
this.time = time;
this.reauthInfo = new ReauthInfo();
try { try {
setSaslState(handshakeRequestEnable ? SaslState.SEND_APIVERSIONS_REQUEST : SaslState.INITIAL); setSaslState(handshakeRequestEnable ? SaslState.SEND_APIVERSIONS_REQUEST : SaslState.INITIAL);
@ -163,7 +192,6 @@ public class SaslClientAuthenticator implements Authenticator {
* followed by N bytes representing the opaque payload. * followed by N bytes representing the opaque payload.
*/ */
public void authenticate() throws IOException { public void authenticate() throws IOException {
short saslHandshakeVersion = 0;
if (netOutBuffer != null && !flushNetOutBufferAndUpdateInterestOps()) if (netOutBuffer != null && !flushNetOutBufferAndUpdateInterestOps())
return; return;
@ -179,16 +207,13 @@ public class SaslClientAuthenticator implements Authenticator {
if (apiVersionsResponse == null) if (apiVersionsResponse == null)
break; break;
else { else {
saslHandshakeVersion = apiVersionsResponse.apiVersion(ApiKeys.SASL_HANDSHAKE.id).maxVersion; saslAuthenticateVersion(apiVersionsResponse);
ApiVersion authenticateVersion = apiVersionsResponse.apiVersion(ApiKeys.SASL_AUTHENTICATE.id); reauthInfo.apiVersionsResponseReceivedFromBroker = apiVersionsResponse;
if (authenticateVersion != null)
saslAuthenticateVersion((short) Math.min(authenticateVersion.maxVersion, ApiKeys.SASL_AUTHENTICATE.latestVersion()));
setSaslState(SaslState.SEND_HANDSHAKE_REQUEST); setSaslState(SaslState.SEND_HANDSHAKE_REQUEST);
// Fall through to send handshake request with the latest supported version // Fall through to send handshake request with the latest supported version
} }
case SEND_HANDSHAKE_REQUEST: case SEND_HANDSHAKE_REQUEST:
SaslHandshakeRequest handshakeRequest = createSaslHandshakeRequest(saslHandshakeVersion); sendHandshakeRequest(reauthInfo.apiVersionsResponseReceivedFromBroker);
send(handshakeRequest.toSend(node, nextRequestHeader(ApiKeys.SASL_HANDSHAKE, handshakeRequest.version())));
setSaslState(SaslState.RECEIVE_HANDSHAKE_RESPONSE); setSaslState(SaslState.RECEIVE_HANDSHAKE_RESPONSE);
break; break;
case RECEIVE_HANDSHAKE_RESPONSE: case RECEIVE_HANDSHAKE_RESPONSE:
@ -201,7 +226,32 @@ public class SaslClientAuthenticator implements Authenticator {
// Fall through and start SASL authentication using the configured client mechanism // Fall through and start SASL authentication using the configured client mechanism
} }
case INITIAL: case INITIAL:
sendSaslClientToken(new byte[0], true); sendInitialToken();
setSaslState(SaslState.INTERMEDIATE);
break;
case REAUTH_PROCESS_ORIG_APIVERSIONS_RESPONSE:
saslAuthenticateVersion(reauthInfo.apiVersionsResponseFromOriginalAuthentication);
setSaslState(SaslState.REAUTH_SEND_HANDSHAKE_REQUEST); // Will set immediately
// Fall through to send handshake request with the latest supported version
case REAUTH_SEND_HANDSHAKE_REQUEST:
sendHandshakeRequest(reauthInfo.apiVersionsResponseFromOriginalAuthentication);
setSaslState(SaslState.REAUTH_RECEIVE_HANDSHAKE_OR_OTHER_RESPONSE);
break;
case REAUTH_RECEIVE_HANDSHAKE_OR_OTHER_RESPONSE:
handshakeResponse = (SaslHandshakeResponse) receiveKafkaResponse();
if (handshakeResponse == null)
break;
handleSaslHandshakeResponse(handshakeResponse);
setSaslState(SaslState.REAUTH_INITIAL); // Will set immediately
/*
* Fall through and start SASL authentication using the configured client
* mechanism. Note that we have to either fall through or add a loop to enter
* the switch statement again. We will fall through to avoid adding the loop and
* therefore minimize the changes to authentication-related code due to the
* changes related to re-authentication.
*/
case REAUTH_INITIAL:
sendInitialToken();
setSaslState(SaslState.INTERMEDIATE); setSaslState(SaslState.INTERMEDIATE);
break; break;
case INTERMEDIATE: case INTERMEDIATE:
@ -229,6 +279,46 @@ public class SaslClientAuthenticator implements Authenticator {
} }
} }
private void sendHandshakeRequest(ApiVersionsResponse apiVersionsResponse) throws IOException {
SaslHandshakeRequest handshakeRequest = createSaslHandshakeRequest(
apiVersionsResponse.apiVersion(ApiKeys.SASL_HANDSHAKE.id).maxVersion);
send(handshakeRequest.toSend(node, nextRequestHeader(ApiKeys.SASL_HANDSHAKE, handshakeRequest.version())));
}
private void sendInitialToken() throws IOException {
sendSaslClientToken(new byte[0], true);
}
@Override
public void reauthenticate(ReauthenticationContext reauthenticationContext) throws IOException {
SaslClientAuthenticator previousSaslClientAuthenticator = (SaslClientAuthenticator) Objects
.requireNonNull(reauthenticationContext).previousAuthenticator();
ApiVersionsResponse apiVersionsResponseFromOriginalAuthentication = previousSaslClientAuthenticator.reauthInfo
.apiVersionsResponse();
previousSaslClientAuthenticator.close();
reauthInfo.reauthenticating(apiVersionsResponseFromOriginalAuthentication,
reauthenticationContext.reauthenticationBeginNanos());
NetworkReceive netInBufferFromChannel = reauthenticationContext.networkReceive();
netInBuffer = netInBufferFromChannel;
setSaslState(SaslState.REAUTH_PROCESS_ORIG_APIVERSIONS_RESPONSE); // Will set immediately
authenticate();
}
@Override
public List<NetworkReceive> getAndClearResponsesReceivedDuringReauthentication() {
return reauthInfo.getAndClearResponsesReceivedDuringReauthentication();
}
@Override
public Long clientSessionReauthenticationTimeNanos() {
return reauthInfo.clientSessionReauthenticationTimeNanos;
}
@Override
public Long reauthenticationLatencyMs() {
return reauthInfo.reauthenticationLatencyMs();
}
private RequestHeader nextRequestHeader(ApiKeys apiKey, short version) { private RequestHeader nextRequestHeader(ApiKeys apiKey, short version) {
String clientId = (String) configs.get(CommonClientConfigs.CLIENT_ID_CONFIG); String clientId = (String) configs.get(CommonClientConfigs.CLIENT_ID_CONFIG);
currentRequestHeader = new RequestHeader(apiKey, version, clientId, correlationId++); currentRequestHeader = new RequestHeader(apiKey, version, clientId, correlationId++);
@ -241,8 +331,11 @@ public class SaslClientAuthenticator implements Authenticator {
} }
// Visible to override for testing // Visible to override for testing
protected void saslAuthenticateVersion(short version) { protected void saslAuthenticateVersion(ApiVersionsResponse apiVersionsResponse) {
this.saslAuthenticateVersion = version; ApiVersion authenticateVersion = apiVersionsResponse.apiVersion(ApiKeys.SASL_AUTHENTICATE.id);
if (authenticateVersion != null)
this.saslAuthenticateVersion = (short) Math.min(authenticateVersion.maxVersion,
ApiKeys.SASL_AUTHENTICATE.latestVersion());
} }
private void setSaslState(SaslState saslState) { private void setSaslState(SaslState saslState) {
@ -252,8 +345,17 @@ public class SaslClientAuthenticator implements Authenticator {
this.pendingSaslState = null; this.pendingSaslState = null;
this.saslState = saslState; this.saslState = saslState;
LOG.debug("Set SASL client state to {}", saslState); LOG.debug("Set SASL client state to {}", saslState);
if (saslState == SaslState.COMPLETE) if (saslState == SaslState.COMPLETE) {
transportLayer.removeInterestOps(SelectionKey.OP_WRITE); reauthInfo.setAuthenticationEndAndSessionReauthenticationTimes(time.nanoseconds());
if (!reauthInfo.reauthenticating())
transportLayer.removeInterestOps(SelectionKey.OP_WRITE);
else
/*
* Re-authentication is triggered by a write, so we have to make sure that
* pending write is actually sent.
*/
transportLayer.addInterestOps(SelectionKey.OP_WRITE);
}
} }
} }
@ -337,6 +439,9 @@ public class SaslClientAuthenticator implements Authenticator {
String errMsg = response.errorMessage(); String errMsg = response.errorMessage();
throw errMsg == null ? error.exception() : error.exception(errMsg); throw errMsg == null ? error.exception() : error.exception(errMsg);
} }
long sessionLifetimeMs = response.sessionLifetimeMs();
if (sessionLifetimeMs > 0L)
reauthInfo.positiveSessionLifetimeMs = sessionLifetimeMs;
return Utils.readBytes(response.saslAuthBytes()); return Utils.readBytes(response.saslAuthBytes());
} else } else
return null; return null;
@ -384,6 +489,9 @@ public class SaslClientAuthenticator implements Authenticator {
} }
private AbstractResponse receiveKafkaResponse() throws IOException { private AbstractResponse receiveKafkaResponse() throws IOException {
if (netInBuffer == null)
netInBuffer = new NetworkReceive(node);
NetworkReceive receive = netInBuffer;
try { try {
byte[] responseBytes = receiveResponseOrToken(); byte[] responseBytes = receiveResponseOrToken();
if (responseBytes == null) if (responseBytes == null)
@ -394,6 +502,19 @@ public class SaslClientAuthenticator implements Authenticator {
return response; return response;
} }
} catch (SchemaException | IllegalArgumentException e) { } catch (SchemaException | IllegalArgumentException e) {
/*
* Account for the fact that during re-authentication there may be responses
* arriving for requests that were sent in the past.
*/
if (reauthInfo.reauthenticating()) {
/*
* It didn't match the current request header, so it must be unrelated to
* re-authentication. Save it so it can be processed later.
*/
receive.payload().rewind();
reauthInfo.pendingAuthenticatedReceives.add(receive);
return null;
}
LOG.debug("Invalid SASL mechanism response, server may be expecting only GSSAPI tokens"); LOG.debug("Invalid SASL mechanism response, server may be expecting only GSSAPI tokens");
setSaslState(SaslState.FAILED); setSaslState(SaslState.FAILED);
throw new IllegalSaslStateException("Invalid SASL mechanism response, server may be expecting a different protocol", e); throw new IllegalSaslStateException("Invalid SASL mechanism response, server may be expecting a different protocol", e);
@ -436,4 +557,81 @@ public class SaslClientAuthenticator implements Authenticator {
} }
} }
/**
* Information related to re-authentication
*/
private static class ReauthInfo {
public ApiVersionsResponse apiVersionsResponseFromOriginalAuthentication;
public long reauthenticationBeginNanos;
public List<NetworkReceive> pendingAuthenticatedReceives = new ArrayList<>();
public ApiVersionsResponse apiVersionsResponseReceivedFromBroker;
public Long positiveSessionLifetimeMs;
public long authenticationEndNanos;
public Long clientSessionReauthenticationTimeNanos;
public void reauthenticating(ApiVersionsResponse apiVersionsResponseFromOriginalAuthentication,
long reauthenticationBeginNanos) {
this.apiVersionsResponseFromOriginalAuthentication = Objects
.requireNonNull(apiVersionsResponseFromOriginalAuthentication);
this.reauthenticationBeginNanos = reauthenticationBeginNanos;
}
public boolean reauthenticating() {
return apiVersionsResponseFromOriginalAuthentication != null;
}
public ApiVersionsResponse apiVersionsResponse() {
return reauthenticating() ? apiVersionsResponseFromOriginalAuthentication
: apiVersionsResponseReceivedFromBroker;
}
/**
* Return the (always non-null but possibly empty) NetworkReceive responses that
* arrived during re-authentication that are unrelated to re-authentication, if
* any. These correspond to requests sent prior to the beginning of
* re-authentication; the requests were made when the channel was successfully
* authenticated, and the responses arrived during the re-authentication
* process.
*
* @return the (always non-null but possibly empty) NetworkReceive responses
* that arrived during re-authentication that are unrelated to
* re-authentication, if any
*/
public List<NetworkReceive> getAndClearResponsesReceivedDuringReauthentication() {
if (pendingAuthenticatedReceives.isEmpty())
return Collections.emptyList();
List<NetworkReceive> retval = pendingAuthenticatedReceives;
pendingAuthenticatedReceives = new ArrayList<>();
return retval;
}
public void setAuthenticationEndAndSessionReauthenticationTimes(long nowNanos) {
authenticationEndNanos = nowNanos;
long sessionLifetimeMsToUse = 0;
if (positiveSessionLifetimeMs != null) {
// pick a random percentage between 85% and 95% for session re-authentication
double pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount = 0.85;
double pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously = 0.10;
double pctToUse = pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + RNG.nextDouble()
* pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously;
sessionLifetimeMsToUse = (long) (positiveSessionLifetimeMs.longValue() * pctToUse);
clientSessionReauthenticationTimeNanos = authenticationEndNanos + 1000 * 1000 * sessionLifetimeMsToUse;
LOG.debug(
"Finished {} with session expiration in {} ms and session re-authentication on or after {} ms",
authenticationOrReauthenticationText(), positiveSessionLifetimeMs, sessionLifetimeMsToUse);
} else
LOG.debug("Finished {} with no session expiration and no session re-authentication",
authenticationOrReauthenticationText());
}
public Long reauthenticationLatencyMs() {
return reauthenticating()
? Long.valueOf(Math.round((authenticationEndNanos - reauthenticationBeginNanos) / 1000.0 / 1000.0))
: null;
}
private String authenticationOrReauthenticationText() {
return reauthenticating() ? "re-authentication" : "authentication";
}
}
} }

View File

@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.authenticator;
import org.apache.kafka.common.config.internals.BrokerSecurityConfigs;
public class SaslInternalConfigs {
/**
* The server (broker) specifies a positive session length in milliseconds to a
* SASL client when {@link BrokerSecurityConfigs#CONNECTIONS_MAX_REAUTH_MS} is
* positive as per <a href=
* "https://cwiki.apache.org/confluence/display/KAFKA/KIP-368%3A+Allow+SASL+Connections+to+Periodically+Re-Authenticate">KIP
* 368: Allow SASL Connections to Periodically Re-Authenticate</a>. The session
* length is the minimum of the configured value and any session length implied
* by the credential presented during authentication. The lifetime defined by
* the credential, in terms of milliseconds since the epoch, is available via a
* negotiated property on the SASL Server instance, and that value can be
* converted to a session length by subtracting the time at which authentication
* occurred. This variable defines the negotiated property key that is used to
* communicate the credential lifetime in milliseconds since the epoch.
*/
public static final String CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY = "CREDENTIAL.LIFETIME.MS";
private SaslInternalConfigs() {
// empty
}
}

View File

@ -30,6 +30,7 @@ import org.apache.kafka.common.network.ChannelBuilders;
import org.apache.kafka.common.network.ListenerName; import org.apache.kafka.common.network.ListenerName;
import org.apache.kafka.common.network.NetworkReceive; import org.apache.kafka.common.network.NetworkReceive;
import org.apache.kafka.common.network.NetworkSend; import org.apache.kafka.common.network.NetworkSend;
import org.apache.kafka.common.network.ReauthenticationContext;
import org.apache.kafka.common.network.Send; import org.apache.kafka.common.network.Send;
import org.apache.kafka.common.network.TransportLayer; import org.apache.kafka.common.network.TransportLayer;
import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ApiKeys;
@ -54,6 +55,7 @@ import org.apache.kafka.common.security.kerberos.KerberosName;
import org.apache.kafka.common.security.kerberos.KerberosShortNamer; import org.apache.kafka.common.security.kerberos.KerberosShortNamer;
import org.apache.kafka.common.security.scram.ScramLoginModule; import org.apache.kafka.common.security.scram.ScramLoginModule;
import org.apache.kafka.common.security.scram.internals.ScramMechanism; import org.apache.kafka.common.security.scram.internals.ScramMechanism;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Utils; import org.apache.kafka.common.utils.Utils;
import org.ietf.jgss.GSSContext; import org.ietf.jgss.GSSContext;
import org.ietf.jgss.GSSCredential; import org.ietf.jgss.GSSCredential;
@ -75,25 +77,40 @@ import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey; import java.nio.channels.SelectionKey;
import java.security.PrivilegedActionException; import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction; import java.security.PrivilegedExceptionAction;
import java.util.Date;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Objects;
import java.util.Set; import java.util.Set;
public class SaslServerAuthenticator implements Authenticator { public class SaslServerAuthenticator implements Authenticator {
// GSSAPI limits requests to 64K, but we allow a bit extra for custom SASL mechanisms // GSSAPI limits requests to 64K, but we allow a bit extra for custom SASL mechanisms
static final int MAX_RECEIVE_SIZE = 524288; static final int MAX_RECEIVE_SIZE = 524288;
private static final Logger LOG = LoggerFactory.getLogger(SaslServerAuthenticator.class); private static final Logger LOG = LoggerFactory.getLogger(SaslServerAuthenticator.class);
private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0); private static final ByteBuffer EMPTY_BUFFER = ByteBuffer.allocate(0);
/**
* The internal state transitions for initial authentication of a channel on the
* server side are declared in order, starting with {@link #INITIAL_REQUEST} and
* ending in either {@link #COMPLETE} or {@link #FAILED}.
* <p>
* Re-authentication of a channel on the server side starts with the state
* {@link #REAUTH_PROCESS_HANDSHAKE}. It may then flow to
* {@link #REAUTH_BAD_MECHANISM} before a transition to {@link #FAILED} if
* re-authentication is attempted with a mechanism different than the original
* one; otherwise it joins the authentication flow at the {@link #AUTHENTICATE}
* state and likewise ends at either {@link #COMPLETE} or {@link #FAILED}.
*/
private enum SaslState { private enum SaslState {
INITIAL_REQUEST, // May be GSSAPI token, SaslHandshake or ApiVersions INITIAL_REQUEST, // May be GSSAPI token, SaslHandshake or ApiVersions for authentication
HANDSHAKE_OR_VERSIONS_REQUEST, // May be SaslHandshake or ApiVersions HANDSHAKE_OR_VERSIONS_REQUEST, // May be SaslHandshake or ApiVersions
HANDSHAKE_REQUEST, // After an ApiVersions request, next request must be SaslHandshake HANDSHAKE_REQUEST, // After an ApiVersions request, next request must be SaslHandshake
AUTHENTICATE, // Authentication tokens (SaslHandshake v1 and above indicate SaslAuthenticate headers) AUTHENTICATE, // Authentication tokens (SaslHandshake v1 and above indicate SaslAuthenticate headers)
COMPLETE, // Authentication completed successfully COMPLETE, // Authentication completed successfully
FAILED // Authentication failed FAILED, // Authentication failed
REAUTH_PROCESS_HANDSHAKE, // Initial state for re-authentication, processes SASL handshake request
REAUTH_BAD_MECHANISM, // When re-authentication requested with wrong mechanism, generate exception
} }
private final SecurityProtocol securityProtocol; private final SecurityProtocol securityProtocol;
@ -105,6 +122,9 @@ public class SaslServerAuthenticator implements Authenticator {
private final Map<String, ?> configs; private final Map<String, ?> configs;
private final KafkaPrincipalBuilder principalBuilder; private final KafkaPrincipalBuilder principalBuilder;
private final Map<String, AuthenticateCallbackHandler> callbackHandlers; private final Map<String, AuthenticateCallbackHandler> callbackHandlers;
private final Map<String, Long> connectionsMaxReauthMsByMechanism;
private final Time time;
private final ReauthInfo reauthInfo;
// Current SASL state // Current SASL state
private SaslState saslState = SaslState.INITIAL_REQUEST; private SaslState saslState = SaslState.INITIAL_REQUEST;
@ -129,7 +149,9 @@ public class SaslServerAuthenticator implements Authenticator {
KerberosShortNamer kerberosNameParser, KerberosShortNamer kerberosNameParser,
ListenerName listenerName, ListenerName listenerName,
SecurityProtocol securityProtocol, SecurityProtocol securityProtocol,
TransportLayer transportLayer) { TransportLayer transportLayer,
Map<String, Long> connectionsMaxReauthMsByMechanism,
Time time) {
this.callbackHandlers = callbackHandlers; this.callbackHandlers = callbackHandlers;
this.connectionId = connectionId; this.connectionId = connectionId;
this.subjects = subjects; this.subjects = subjects;
@ -137,6 +159,9 @@ public class SaslServerAuthenticator implements Authenticator {
this.securityProtocol = securityProtocol; this.securityProtocol = securityProtocol;
this.enableKafkaSaslAuthenticateHeaders = false; this.enableKafkaSaslAuthenticateHeaders = false;
this.transportLayer = transportLayer; this.transportLayer = transportLayer;
this.connectionsMaxReauthMsByMechanism = connectionsMaxReauthMsByMechanism;
this.time = time;
this.reauthInfo = new ReauthInfo();
this.configs = configs; this.configs = configs;
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
@ -149,6 +174,8 @@ public class SaslServerAuthenticator implements Authenticator {
throw new IllegalArgumentException("Callback handler not specified for SASL mechanism " + mechanism); throw new IllegalArgumentException("Callback handler not specified for SASL mechanism " + mechanism);
if (!subjects.containsKey(mechanism)) if (!subjects.containsKey(mechanism))
throw new IllegalArgumentException("Subject cannot be null for SASL mechanism " + mechanism); throw new IllegalArgumentException("Subject cannot be null for SASL mechanism " + mechanism);
LOG.debug("{} for mechanism={}: {}", BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS, mechanism,
connectionsMaxReauthMsByMechanism.get(mechanism));
} }
// Note that the old principal builder does not support SASL, so we do not need to pass the // Note that the old principal builder does not support SASL, so we do not need to pass the
@ -224,52 +251,58 @@ public class SaslServerAuthenticator implements Authenticator {
*/ */
@Override @Override
public void authenticate() throws IOException { public void authenticate() throws IOException {
if (netOutBuffer != null && !flushNetOutBufferAndUpdateInterestOps()) if (saslState != SaslState.REAUTH_PROCESS_HANDSHAKE) {
return; if (netOutBuffer != null && !flushNetOutBufferAndUpdateInterestOps())
return;
if (saslServer != null && saslServer.isComplete()) {
setSaslState(SaslState.COMPLETE); if (saslServer != null && saslServer.isComplete()) {
return; setSaslState(SaslState.COMPLETE);
} return;
if (netInBuffer == null) netInBuffer = new NetworkReceive(MAX_RECEIVE_SIZE, connectionId);
netInBuffer.readFrom(transportLayer);
if (netInBuffer.complete()) {
netInBuffer.payload().rewind();
byte[] clientToken = new byte[netInBuffer.payload().remaining()];
netInBuffer.payload().get(clientToken, 0, clientToken.length);
netInBuffer = null; // reset the networkReceive as we read all the data.
try {
switch (saslState) {
case HANDSHAKE_OR_VERSIONS_REQUEST:
case HANDSHAKE_REQUEST:
handleKafkaRequest(clientToken);
break;
case INITIAL_REQUEST:
if (handleKafkaRequest(clientToken))
break;
// For default GSSAPI, fall through to authenticate using the client token as the first GSSAPI packet.
// This is required for interoperability with 0.9.0.x clients which do not send handshake request
case AUTHENTICATE:
handleSaslToken(clientToken);
// When the authentication exchange is complete and no more tokens are expected from the client,
// update SASL state. Current SASL state will be updated when outgoing writes to the client complete.
if (saslServer.isComplete())
setSaslState(SaslState.COMPLETE);
break;
default:
break;
}
} catch (AuthenticationException e) {
// Exception will be propagated after response is sent to client
setSaslState(SaslState.FAILED, e);
} catch (Exception e) {
// In the case of IOExceptions and other unexpected exceptions, fail immediately
saslState = SaslState.FAILED;
throw e;
} }
// allocate on heap (as opposed to any socket server memory pool)
if (netInBuffer == null) netInBuffer = new NetworkReceive(MAX_RECEIVE_SIZE, connectionId);
netInBuffer.readFrom(transportLayer);
if (!netInBuffer.complete())
return;
netInBuffer.payload().rewind();
}
byte[] clientToken = new byte[netInBuffer.payload().remaining()];
netInBuffer.payload().get(clientToken, 0, clientToken.length);
netInBuffer = null; // reset the networkReceive as we read all the data.
try {
switch (saslState) {
case REAUTH_PROCESS_HANDSHAKE:
case HANDSHAKE_OR_VERSIONS_REQUEST:
case HANDSHAKE_REQUEST:
handleKafkaRequest(clientToken);
break;
case REAUTH_BAD_MECHANISM:
throw new SaslAuthenticationException(reauthInfo.badMechanismErrorMessage);
case INITIAL_REQUEST:
if (handleKafkaRequest(clientToken))
break;
// For default GSSAPI, fall through to authenticate using the client token as the first GSSAPI packet.
// This is required for interoperability with 0.9.0.x clients which do not send handshake request
case AUTHENTICATE:
handleSaslToken(clientToken);
// When the authentication exchange is complete and no more tokens are expected from the client,
// update SASL state. Current SASL state will be updated when outgoing writes to the client complete.
if (saslServer.isComplete())
setSaslState(SaslState.COMPLETE);
break;
default:
break;
}
} catch (AuthenticationException e) {
// Exception will be propagated after response is sent to client
setSaslState(SaslState.FAILED, e);
} catch (Exception e) {
// In the case of IOExceptions and other unexpected exceptions, fail immediately
saslState = SaslState.FAILED;
LOG.debug("Failed during {}: {}", reauthInfo.authenticationOrReauthenticationText(), e.getMessage());
throw e;
} }
} }
@ -301,6 +334,38 @@ public class SaslServerAuthenticator implements Authenticator {
saslServer.dispose(); saslServer.dispose();
} }
@Override
public void reauthenticate(ReauthenticationContext reauthenticationContext) throws IOException {
NetworkReceive saslHandshakeReceive = reauthenticationContext.networkReceive();
if (saslHandshakeReceive == null)
throw new IllegalArgumentException(
"Invalid saslHandshakeReceive in server-side re-authentication context: null");
SaslServerAuthenticator previousSaslServerAuthenticator = (SaslServerAuthenticator) reauthenticationContext.previousAuthenticator();
reauthInfo.reauthenticating(previousSaslServerAuthenticator.saslMechanism,
previousSaslServerAuthenticator.principal(), reauthenticationContext.reauthenticationBeginNanos());
previousSaslServerAuthenticator.close();
netInBuffer = saslHandshakeReceive;
LOG.debug("Beginning re-authentication: {}", this);
netInBuffer.payload().rewind();
setSaslState(SaslState.REAUTH_PROCESS_HANDSHAKE);
authenticate();
}
@Override
public Long serverSessionExpirationTimeNanos() {
return reauthInfo.sessionExpirationTimeNanos;
}
@Override
public Long reauthenticationLatencyMs() {
return reauthInfo.reauthenticationLatencyMs();
}
@Override
public boolean connectedClientSupportsReauthentication() {
return reauthInfo.connectedClientSupportsReauthentication;
}
private void setSaslState(SaslState saslState) { private void setSaslState(SaslState saslState) {
setSaslState(saslState, null); setSaslState(saslState, null);
} }
@ -311,7 +376,7 @@ public class SaslServerAuthenticator implements Authenticator {
pendingException = exception; pendingException = exception;
} else { } else {
this.saslState = saslState; this.saslState = saslState;
LOG.debug("Set SASL server state to {}", saslState); LOG.debug("Set SASL server state to {} during {}", saslState, reauthInfo.authenticationOrReauthenticationText());
this.pendingSaslState = null; this.pendingSaslState = null;
this.pendingException = null; this.pendingException = null;
if (exception != null) if (exception != null)
@ -347,6 +412,8 @@ public class SaslServerAuthenticator implements Authenticator {
private void handleSaslToken(byte[] clientToken) throws IOException { private void handleSaslToken(byte[] clientToken) throws IOException {
if (!enableKafkaSaslAuthenticateHeaders) { if (!enableKafkaSaslAuthenticateHeaders) {
byte[] response = saslServer.evaluateResponse(clientToken); byte[] response = saslServer.evaluateResponse(clientToken);
if (reauthInfo.reauthenticating() && saslServer.isComplete())
reauthInfo.ensurePrincipalUnchanged(principal());
if (response != null) { if (response != null) {
netOutBuffer = new NetworkSend(connectionId, ByteBuffer.wrap(response)); netOutBuffer = new NetworkSend(connectionId, ByteBuffer.wrap(response));
flushNetOutBufferAndUpdateInterestOps(); flushNetOutBufferAndUpdateInterestOps();
@ -369,13 +436,24 @@ public class SaslServerAuthenticator implements Authenticator {
// This should not normally occur since clients typically check supported versions using ApiVersionsRequest // This should not normally occur since clients typically check supported versions using ApiVersionsRequest
throw new UnsupportedVersionException("Version " + version + " is not supported for apiKey " + apiKey); throw new UnsupportedVersionException("Version " + version + " is not supported for apiKey " + apiKey);
} }
/*
* The client sends multiple SASL_AUTHENTICATE requests, and the client is known
* to support the required version if any one of them indicates it supports that
* version.
*/
if (!reauthInfo.connectedClientSupportsReauthentication)
reauthInfo.connectedClientSupportsReauthentication = version > 0;
SaslAuthenticateRequest saslAuthenticateRequest = (SaslAuthenticateRequest) requestAndSize.request; SaslAuthenticateRequest saslAuthenticateRequest = (SaslAuthenticateRequest) requestAndSize.request;
try { try {
byte[] responseToken = saslServer.evaluateResponse(Utils.readBytes(saslAuthenticateRequest.saslAuthBytes())); byte[] responseToken = saslServer.evaluateResponse(Utils.readBytes(saslAuthenticateRequest.saslAuthBytes()));
if (reauthInfo.reauthenticating() && saslServer.isComplete())
reauthInfo.ensurePrincipalUnchanged(principal());
// For versions with SASL_AUTHENTICATE header, send a response to SASL_AUTHENTICATE request even if token is empty. // For versions with SASL_AUTHENTICATE header, send a response to SASL_AUTHENTICATE request even if token is empty.
ByteBuffer responseBuf = responseToken == null ? EMPTY_BUFFER : ByteBuffer.wrap(responseToken); ByteBuffer responseBuf = responseToken == null ? EMPTY_BUFFER : ByteBuffer.wrap(responseToken);
sendKafkaResponse(requestContext, new SaslAuthenticateResponse(Errors.NONE, null, responseBuf)); long sessionLifetimeMs = !saslServer.isComplete() ? 0L
: reauthInfo.calcCompletionTimesAndReturnSessionLifetimeMs();
sendKafkaResponse(requestContext, new SaslAuthenticateResponse(Errors.NONE, null, responseBuf, sessionLifetimeMs));
} catch (SaslAuthenticationException e) { } catch (SaslAuthenticationException e) {
buildResponseOnAuthenticateFailure(requestContext, buildResponseOnAuthenticateFailure(requestContext,
new SaslAuthenticateResponse(Errors.SASL_AUTHENTICATION_FAILED, e.getMessage())); new SaslAuthenticateResponse(Errors.SASL_AUTHENTICATION_FAILED, e.getMessage()));
@ -386,7 +464,10 @@ public class SaslServerAuthenticator implements Authenticator {
// Handle retriable Kerberos exceptions as I/O exceptions rather than authentication exceptions // Handle retriable Kerberos exceptions as I/O exceptions rather than authentication exceptions
throw e; throw e;
} else { } else {
String errorMessage = "Authentication failed due to invalid credentials with SASL mechanism " + saslMechanism; String errorMessage = "Authentication failed during "
+ reauthInfo.authenticationOrReauthenticationText()
+ " due to invalid credentials with SASL mechanism " + saslMechanism + ": "
+ e.getMessage();
sendKafkaResponse(requestContext, new SaslAuthenticateResponse(Errors.SASL_AUTHENTICATION_FAILED, sendKafkaResponse(requestContext, new SaslAuthenticateResponse(Errors.SASL_AUTHENTICATION_FAILED,
errorMessage)); errorMessage));
throw new SaslAuthenticationException(errorMessage, e); throw new SaslAuthenticationException(errorMessage, e);
@ -414,7 +495,7 @@ public class SaslServerAuthenticator implements Authenticator {
if (apiKey != ApiKeys.API_VERSIONS && apiKey != ApiKeys.SASL_HANDSHAKE) if (apiKey != ApiKeys.API_VERSIONS && apiKey != ApiKeys.SASL_HANDSHAKE)
throw new IllegalSaslStateException("Unexpected Kafka request of type " + apiKey + " during SASL handshake."); throw new IllegalSaslStateException("Unexpected Kafka request of type " + apiKey + " during SASL handshake.");
LOG.debug("Handling Kafka request {}", apiKey); LOG.debug("Handling Kafka request {} during {}", apiKey, reauthInfo.authenticationOrReauthenticationText());
RequestContext requestContext = new RequestContext(header, connectionId, clientAddress(), RequestContext requestContext = new RequestContext(header, connectionId, clientAddress(),
@ -446,7 +527,8 @@ public class SaslServerAuthenticator implements Authenticator {
} else } else
throw e; throw e;
} }
if (clientMechanism != null) { if (clientMechanism != null && (!reauthInfo.reauthenticating()
|| reauthInfo.saslMechanismUnchanged(clientMechanism))) {
createSaslServer(clientMechanism); createSaslServer(clientMechanism);
setSaslState(SaslState.AUTHENTICATE); setSaslState(SaslState.AUTHENTICATE);
} }
@ -517,4 +599,116 @@ public class SaslServerAuthenticator implements Authenticator {
netOutBuffer = send; netOutBuffer = send;
flushNetOutBufferAndUpdateInterestOps(); flushNetOutBufferAndUpdateInterestOps();
} }
/**
* Information related to re-authentication
*/
private class ReauthInfo {
public String previousSaslMechanism;
public KafkaPrincipal previousKafkaPrincipal;
public long reauthenticationBeginNanos;
public Long sessionExpirationTimeNanos;
public boolean connectedClientSupportsReauthentication;
public long authenticationEndNanos;
public String badMechanismErrorMessage;
public void reauthenticating(String previousSaslMechanism, KafkaPrincipal previousKafkaPrincipal,
long reauthenticationBeginNanos) {
this.previousSaslMechanism = Objects.requireNonNull(previousSaslMechanism);
this.previousKafkaPrincipal = Objects.requireNonNull(previousKafkaPrincipal);
this.reauthenticationBeginNanos = reauthenticationBeginNanos;
}
public boolean reauthenticating() {
return previousSaslMechanism != null;
}
public String authenticationOrReauthenticationText() {
return reauthenticating() ? "re-authentication" : "authentication";
}
public void ensurePrincipalUnchanged(KafkaPrincipal reauthenticatedKafkaPrincipal) throws SaslAuthenticationException {
if (!previousKafkaPrincipal.equals(reauthenticatedKafkaPrincipal)) {
throw new SaslAuthenticationException(String.format(
"Cannot change principals during re-authentication from %s.%s: %s.%s",
previousKafkaPrincipal.getPrincipalType(), previousKafkaPrincipal.getName(),
reauthenticatedKafkaPrincipal.getPrincipalType(), reauthenticatedKafkaPrincipal.getName()));
}
}
/*
* We define the REAUTH_BAD_MECHANISM state because the failed re-authentication
* metric does not get updated if we send back an error immediately upon the
* start of re-authentication.
*/
public boolean saslMechanismUnchanged(String clientMechanism) {
if (previousSaslMechanism.equals(clientMechanism))
return true;
badMechanismErrorMessage = String.format(
"SASL mechanism '%s' requested by client is not supported for re-authentication of mechanism '%s'",
clientMechanism, previousSaslMechanism);
LOG.debug(badMechanismErrorMessage);
setSaslState(SaslState.REAUTH_BAD_MECHANISM);
return false;
}
private long calcCompletionTimesAndReturnSessionLifetimeMs() {
long retvalSessionLifetimeMs = 0L;
long authenticationEndMs = time.milliseconds();
authenticationEndNanos = time.nanoseconds();
Long credentialExpirationMs = (Long) saslServer
.getNegotiatedProperty(SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY);
Long connectionsMaxReauthMs = connectionsMaxReauthMsByMechanism.get(saslMechanism);
if (credentialExpirationMs != null || connectionsMaxReauthMs != null) {
if (credentialExpirationMs == null)
retvalSessionLifetimeMs = zeroIfNegative(connectionsMaxReauthMs.longValue());
else if (connectionsMaxReauthMs == null)
retvalSessionLifetimeMs = zeroIfNegative(credentialExpirationMs.longValue() - authenticationEndMs);
else
retvalSessionLifetimeMs = zeroIfNegative(
Math.min(credentialExpirationMs.longValue() - authenticationEndMs,
connectionsMaxReauthMs.longValue()));
if (retvalSessionLifetimeMs > 0L)
sessionExpirationTimeNanos = Long
.valueOf(authenticationEndNanos + 1000 * 1000 * retvalSessionLifetimeMs);
}
if (credentialExpirationMs != null) {
if (sessionExpirationTimeNanos != null)
LOG.debug(
"Authentication complete; session max lifetime from broker config={} ms, credential expiration={} ({} ms); session expiration = {} ({} ms), sending {} ms to client",
connectionsMaxReauthMs, new Date(credentialExpirationMs),
Long.valueOf(credentialExpirationMs.longValue() - authenticationEndMs),
new Date(authenticationEndMs + retvalSessionLifetimeMs), retvalSessionLifetimeMs,
retvalSessionLifetimeMs);
else
LOG.debug(
"Authentication complete; session max lifetime from broker config={} ms, credential expiration={} ({} ms); no session expiration, sending 0 ms to client",
connectionsMaxReauthMs, new Date(credentialExpirationMs),
Long.valueOf(credentialExpirationMs.longValue() - authenticationEndMs));
} else {
if (sessionExpirationTimeNanos != null)
LOG.debug(
"Authentication complete; session max lifetime from broker config={} ms, no credential expiration; session expiration = {} ({} ms), sending {} ms to client",
connectionsMaxReauthMs, new Date(authenticationEndMs + retvalSessionLifetimeMs),
retvalSessionLifetimeMs, retvalSessionLifetimeMs);
else
LOG.debug(
"Authentication complete; session max lifetime from broker config={} ms, no credential expiration; no session expiration, sending 0 ms to client",
connectionsMaxReauthMs);
}
return retvalSessionLifetimeMs;
}
public Long reauthenticationLatencyMs() {
if (!reauthenticating())
return null;
// record at least 1 ms if there is some latency
long latencyNanos = authenticationEndNanos - reauthenticationBeginNanos;
return latencyNanos == 0L ? 0L : Math.max(1L, Long.valueOf(Math.round(latencyNanos / 1000.0 / 1000.0)));
}
private long zeroIfNegative(long value) {
return Math.max(0L, value);
}
}
} }

View File

@ -32,6 +32,7 @@ import javax.security.sasl.SaslServerFactory;
import org.apache.kafka.common.errors.SaslAuthenticationException; import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.security.auth.SaslExtensions; import org.apache.kafka.common.security.auth.SaslExtensions;
import org.apache.kafka.common.security.authenticator.SaslInternalConfigs;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback; import org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
@ -118,7 +119,8 @@ public class OAuthBearerSaslServer implements SaslServer {
throw new IllegalStateException("Authentication exchange has not completed"); throw new IllegalStateException("Authentication exchange has not completed");
if (NEGOTIATED_PROPERTY_KEY_TOKEN.equals(propName)) if (NEGOTIATED_PROPERTY_KEY_TOKEN.equals(propName))
return tokenForNegotiatedProperty; return tokenForNegotiatedProperty;
if (SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY.equals(propName))
return tokenForNegotiatedProperty.lifetimeMs();
return extensions.map().get(propName); return extensions.map().get(propName);
} }

View File

@ -33,6 +33,7 @@ import javax.security.sasl.SaslServerFactory;
import org.apache.kafka.common.errors.AuthenticationException; import org.apache.kafka.common.errors.AuthenticationException;
import org.apache.kafka.common.errors.IllegalSaslStateException; import org.apache.kafka.common.errors.IllegalSaslStateException;
import org.apache.kafka.common.errors.SaslAuthenticationException; import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.security.authenticator.SaslInternalConfigs;
import org.apache.kafka.common.security.scram.ScramCredential; import org.apache.kafka.common.security.scram.ScramCredential;
import org.apache.kafka.common.security.scram.ScramCredentialCallback; import org.apache.kafka.common.security.scram.ScramCredentialCallback;
import org.apache.kafka.common.security.scram.ScramLoginModule; import org.apache.kafka.common.security.scram.ScramLoginModule;
@ -74,6 +75,7 @@ public class ScramSaslServer implements SaslServer {
private ScramExtensions scramExtensions; private ScramExtensions scramExtensions;
private ScramCredential scramCredential; private ScramCredential scramCredential;
private String authorizationId; private String authorizationId;
private Long tokenExpiryTimestamp;
public ScramSaslServer(ScramMechanism mechanism, Map<String, ?> props, CallbackHandler callbackHandler) throws NoSuchAlgorithmException { public ScramSaslServer(ScramMechanism mechanism, Map<String, ?> props, CallbackHandler callbackHandler) throws NoSuchAlgorithmException {
this.mechanism = mechanism; this.mechanism = mechanism;
@ -115,10 +117,12 @@ public class ScramSaslServer implements SaslServer {
if (tokenCallback.tokenOwner() == null) if (tokenCallback.tokenOwner() == null)
throw new SaslException("Token Authentication failed: Invalid tokenId : " + username); throw new SaslException("Token Authentication failed: Invalid tokenId : " + username);
this.authorizationId = tokenCallback.tokenOwner(); this.authorizationId = tokenCallback.tokenOwner();
this.tokenExpiryTimestamp = tokenCallback.tokenExpiryTimestamp();
} else { } else {
credentialCallback = new ScramCredentialCallback(); credentialCallback = new ScramCredentialCallback();
callbackHandler.handle(new Callback[]{nameCallback, credentialCallback}); callbackHandler.handle(new Callback[]{nameCallback, credentialCallback});
this.authorizationId = username; this.authorizationId = username;
this.tokenExpiryTimestamp = null;
} }
this.scramCredential = credentialCallback.scramCredential(); this.scramCredential = credentialCallback.scramCredential();
if (scramCredential == null) if (scramCredential == null)
@ -181,7 +185,8 @@ public class ScramSaslServer implements SaslServer {
public Object getNegotiatedProperty(String propName) { public Object getNegotiatedProperty(String propName) {
if (!isComplete()) if (!isComplete())
throw new IllegalStateException("Authentication exchange has not completed"); throw new IllegalStateException("Authentication exchange has not completed");
if (SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY.equals(propName))
return tokenExpiryTimestamp; // will be null if token not used
if (SUPPORTED_EXTENSIONS.contains(propName)) if (SUPPORTED_EXTENSIONS.contains(propName))
return scramExtensions.map().get(propName); return scramExtensions.map().get(propName);
else else

View File

@ -28,6 +28,7 @@ import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.authenticator.CredentialCache; import org.apache.kafka.common.security.authenticator.CredentialCache;
import org.apache.kafka.common.security.scram.ScramCredential; import org.apache.kafka.common.security.scram.ScramCredential;
import org.apache.kafka.common.security.scram.ScramCredentialCallback; import org.apache.kafka.common.security.scram.ScramCredentialCallback;
import org.apache.kafka.common.security.token.delegation.TokenInformation;
import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache; import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache;
import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCredentialCallback; import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCredentialCallback;
@ -58,6 +59,9 @@ public class ScramServerCallbackHandler implements AuthenticateCallbackHandler {
DelegationTokenCredentialCallback tokenCallback = (DelegationTokenCredentialCallback) callback; DelegationTokenCredentialCallback tokenCallback = (DelegationTokenCredentialCallback) callback;
tokenCallback.scramCredential(tokenCache.credential(saslMechanism, username)); tokenCallback.scramCredential(tokenCache.credential(saslMechanism, username));
tokenCallback.tokenOwner(tokenCache.owner(username)); tokenCallback.tokenOwner(tokenCache.owner(username));
TokenInformation tokenInfo = tokenCache.token(username);
if (tokenInfo != null)
tokenCallback.tokenExpiryTimestamp(tokenInfo.expiryTimestamp());
} else if (callback instanceof ScramCredentialCallback) { } else if (callback instanceof ScramCredentialCallback) {
ScramCredentialCallback sc = (ScramCredentialCallback) callback; ScramCredentialCallback sc = (ScramCredentialCallback) callback;
sc.scramCredential(credentialCache.get(username)); sc.scramCredential(credentialCache.get(username));

View File

@ -20,6 +20,7 @@ import org.apache.kafka.common.security.scram.ScramCredentialCallback;
public class DelegationTokenCredentialCallback extends ScramCredentialCallback { public class DelegationTokenCredentialCallback extends ScramCredentialCallback {
private String tokenOwner; private String tokenOwner;
private Long tokenExpiryTimestamp;
public void tokenOwner(String tokenOwner) { public void tokenOwner(String tokenOwner) {
this.tokenOwner = tokenOwner; this.tokenOwner = tokenOwner;
@ -28,4 +29,12 @@ public class DelegationTokenCredentialCallback extends ScramCredentialCallback {
public String tokenOwner() { public String tokenOwner() {
return tokenOwner; return tokenOwner;
} }
public void tokenExpiryTimestamp(Long tokenExpiryTimestamp) {
this.tokenExpiryTimestamp = tokenExpiryTimestamp;
}
public Long tokenExpiryTimestamp() {
return tokenExpiryTimestamp;
}
} }

View File

@ -28,6 +28,7 @@ import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.security.auth.SecurityProtocol; import org.apache.kafka.common.security.auth.SecurityProtocol;
import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.security.authenticator.CredentialCache; import org.apache.kafka.common.security.authenticator.CredentialCache;
import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache;
import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Utils; import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.test.TestUtils; import org.apache.kafka.test.TestUtils;
@ -50,6 +51,15 @@ public class NetworkTestUtils {
return server; return server;
} }
public static NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol,
AbstractConfig serverConfig, CredentialCache credentialCache,
int failedAuthenticationDelayMs, Time time, DelegationTokenCache tokenCache) throws Exception {
NioEchoServer server = new NioEchoServer(listenerName, securityProtocol, serverConfig, "localhost",
null, credentialCache, failedAuthenticationDelayMs, time, tokenCache);
server.start();
return server;
}
public static Selector createSelector(ChannelBuilder channelBuilder, Time time) { public static Selector createSelector(ChannelBuilder channelBuilder, Time time) {
return new Selector(5000, new Metrics(), time, "MetricGroup", channelBuilder, new LogContext()); return new Selector(5000, new Metrics(), time, "MetricGroup", channelBuilder, new LogContext());
} }

View File

@ -22,13 +22,13 @@ import org.apache.kafka.common.MetricName;
import org.apache.kafka.common.config.AbstractConfig; import org.apache.kafka.common.config.AbstractConfig;
import org.apache.kafka.common.metrics.KafkaMetric; import org.apache.kafka.common.metrics.KafkaMetric;
import org.apache.kafka.common.metrics.Metrics; import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.security.auth.SecurityProtocol; import org.apache.kafka.common.security.auth.SecurityProtocol;
import org.apache.kafka.common.security.authenticator.CredentialCache; import org.apache.kafka.common.security.authenticator.CredentialCache;
import org.apache.kafka.common.security.scram.ScramCredential; import org.apache.kafka.common.security.scram.ScramCredential;
import org.apache.kafka.common.security.scram.internals.ScramMechanism; import org.apache.kafka.common.security.scram.internals.ScramMechanism;
import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.Time; import org.apache.kafka.common.utils.Time;
import org.apache.kafka.test.TestCondition;
import org.apache.kafka.test.TestUtils; import org.apache.kafka.test.TestUtils;
import java.io.IOException; import java.io.IOException;
@ -40,9 +40,12 @@ import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel; import java.nio.channels.WritableByteChannel;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.EnumSet;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Set;
import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache; import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache;
@ -52,6 +55,19 @@ import org.apache.kafka.common.security.token.delegation.internals.DelegationTok
* *
*/ */
public class NioEchoServer extends Thread { public class NioEchoServer extends Thread {
public enum MetricType {
TOTAL, RATE, AVG, MAX;
private final String metricNameSuffix;
private MetricType() {
metricNameSuffix = "-" + name().toLowerCase(Locale.ROOT);
}
public String metricNameSuffix() {
return metricNameSuffix;
}
}
private static final double EPS = 0.0001; private static final double EPS = 0.0001;
@ -67,7 +83,8 @@ public class NioEchoServer extends Thread {
private volatile int numSent = 0; private volatile int numSent = 0;
private volatile boolean closeKafkaChannels; private volatile boolean closeKafkaChannels;
private final DelegationTokenCache tokenCache; private final DelegationTokenCache tokenCache;
private final Time time;
public NioEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, AbstractConfig config, public NioEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, AbstractConfig config,
String serverHost, ChannelBuilder channelBuilder, CredentialCache credentialCache, Time time) throws Exception { String serverHost, ChannelBuilder channelBuilder, CredentialCache credentialCache, Time time) throws Exception {
this(listenerName, securityProtocol, config, serverHost, channelBuilder, credentialCache, 100, time); this(listenerName, securityProtocol, config, serverHost, channelBuilder, credentialCache, 100, time);
@ -76,6 +93,13 @@ public class NioEchoServer extends Thread {
public NioEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, AbstractConfig config, public NioEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, AbstractConfig config,
String serverHost, ChannelBuilder channelBuilder, CredentialCache credentialCache, String serverHost, ChannelBuilder channelBuilder, CredentialCache credentialCache,
int failedAuthenticationDelayMs, Time time) throws Exception { int failedAuthenticationDelayMs, Time time) throws Exception {
this(listenerName, securityProtocol, config, serverHost, channelBuilder, credentialCache, 100, time,
new DelegationTokenCache(ScramMechanism.mechanismNames()));
}
public NioEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol, AbstractConfig config,
String serverHost, ChannelBuilder channelBuilder, CredentialCache credentialCache,
int failedAuthenticationDelayMs, Time time, DelegationTokenCache tokenCache) throws Exception {
super("echoserver"); super("echoserver");
setDaemon(true); setDaemon(true);
serverSocketChannel = ServerSocketChannel.open(); serverSocketChannel = ServerSocketChannel.open();
@ -85,7 +109,7 @@ public class NioEchoServer extends Thread {
this.socketChannels = Collections.synchronizedList(new ArrayList<SocketChannel>()); this.socketChannels = Collections.synchronizedList(new ArrayList<SocketChannel>());
this.newChannels = Collections.synchronizedList(new ArrayList<SocketChannel>()); this.newChannels = Collections.synchronizedList(new ArrayList<SocketChannel>());
this.credentialCache = credentialCache; this.credentialCache = credentialCache;
this.tokenCache = new DelegationTokenCache(ScramMechanism.mechanismNames()); this.tokenCache = tokenCache;
if (securityProtocol == SecurityProtocol.SASL_PLAINTEXT || securityProtocol == SecurityProtocol.SASL_SSL) { if (securityProtocol == SecurityProtocol.SASL_PLAINTEXT || securityProtocol == SecurityProtocol.SASL_SSL) {
for (String mechanism : ScramMechanism.mechanismNames()) { for (String mechanism : ScramMechanism.mechanismNames()) {
if (credentialCache.cache(mechanism, ScramCredential.class) == null) if (credentialCache.cache(mechanism, ScramCredential.class) == null)
@ -93,10 +117,11 @@ public class NioEchoServer extends Thread {
} }
} }
if (channelBuilder == null) if (channelBuilder == null)
channelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, false, securityProtocol, config, credentialCache, tokenCache); channelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, false, securityProtocol, config, credentialCache, tokenCache, time);
this.metrics = new Metrics(); this.metrics = new Metrics();
this.selector = new Selector(10000, failedAuthenticationDelayMs, metrics, time, "MetricGroup", channelBuilder, new LogContext()); this.selector = new Selector(10000, failedAuthenticationDelayMs, metrics, time, "MetricGroup", channelBuilder, new LogContext());
acceptorThread = new AcceptorThread(); acceptorThread = new AcceptorThread();
this.time = time;
} }
public int port() { public int port() {
@ -111,7 +136,6 @@ public class NioEchoServer extends Thread {
return tokenCache; return tokenCache;
} }
@SuppressWarnings("deprecation")
public double metricValue(String name) { public double metricValue(String name) {
for (Map.Entry<MetricName, KafkaMetric> entry : metrics.metrics().entrySet()) { for (Map.Entry<MetricName, KafkaMetric> entry : metrics.metrics().entrySet()) {
if (entry.getKey().name().equals(name)) if (entry.getKey().name().equals(name))
@ -122,29 +146,52 @@ public class NioEchoServer extends Thread {
public void verifyAuthenticationMetrics(int successfulAuthentications, final int failedAuthentications) public void verifyAuthenticationMetrics(int successfulAuthentications, final int failedAuthentications)
throws InterruptedException { throws InterruptedException {
waitForMetric("successful-authentication", successfulAuthentications); waitForMetrics("successful-authentication", successfulAuthentications,
waitForMetric("failed-authentication", failedAuthentications); EnumSet.of(MetricType.TOTAL, MetricType.RATE));
waitForMetrics("failed-authentication", failedAuthentications, EnumSet.of(MetricType.TOTAL, MetricType.RATE));
}
public void verifyReauthenticationMetrics(int successfulReauthentications, final int failedReauthentications)
throws InterruptedException {
waitForMetrics("successful-reauthentication", successfulReauthentications,
EnumSet.of(MetricType.TOTAL, MetricType.RATE));
waitForMetrics("failed-reauthentication", failedReauthentications,
EnumSet.of(MetricType.TOTAL, MetricType.RATE));
waitForMetrics("successful-authentication-no-reauth", 0, EnumSet.of(MetricType.TOTAL));
waitForMetrics("reauthentication-latency", Math.signum(successfulReauthentications),
EnumSet.of(MetricType.MAX, MetricType.AVG));
}
public void verifyAuthenticationNoReauthMetric(int successfulAuthenticationNoReauths) throws InterruptedException {
waitForMetrics("successful-authentication-no-reauth", successfulAuthenticationNoReauths,
EnumSet.of(MetricType.TOTAL));
} }
public void waitForMetric(String name, final double expectedValue) throws InterruptedException { public void waitForMetric(String name, final double expectedValue) throws InterruptedException {
final String totalName = name + "-total"; waitForMetrics(name, expectedValue, EnumSet.of(MetricType.TOTAL, MetricType.RATE));
final String rateName = name + "-rate"; }
if (expectedValue == 0.0) {
assertEquals(expectedValue, metricValue(totalName), EPS); public void waitForMetrics(String namePrefix, final double expectedValue, Set<MetricType> metricTypes)
assertEquals(expectedValue, metricValue(rateName), EPS); throws InterruptedException {
} else { long maxAggregateWaitMs = 15000;
TestUtils.waitForCondition(new TestCondition() { long startMs = time.milliseconds();
@Override for (MetricType metricType : metricTypes) {
public boolean conditionMet() { long currentElapsedMs = time.milliseconds() - startMs;
return Math.abs(metricValue(totalName) - expectedValue) <= EPS; long thisMaxWaitMs = maxAggregateWaitMs - currentElapsedMs;
} String metricName = namePrefix + metricType.metricNameSuffix();
}, "Metric not updated " + totalName); if (expectedValue == 0.0)
TestUtils.waitForCondition(new TestCondition() { assertEquals(
@Override "Metric not updated " + metricName + " expected:<" + expectedValue + "> but was:<"
public boolean conditionMet() { + metricValue(metricName) + ">",
return metricValue(rateName) > 0.0; metricType == MetricType.MAX ? Double.NEGATIVE_INFINITY : 0d, metricValue(metricName), EPS);
} else if (metricType == MetricType.TOTAL)
}, "Metric not updated " + rateName); TestUtils.waitForCondition(() -> Math.abs(metricValue(metricName) - expectedValue) <= EPS,
thisMaxWaitMs, () -> "Metric not updated " + metricName + " expected:<" + expectedValue
+ "> but was:<" + metricValue(metricName) + ">");
else
TestUtils.waitForCondition(() -> metricValue(metricName) > 0.0, thisMaxWaitMs,
() -> "Metric not updated " + metricName + " expected:<a positive number> but was:<"
+ metricValue(metricName) + ">");
} }
} }
@ -170,15 +217,17 @@ public class NioEchoServer extends Thread {
List<NetworkReceive> completedReceives = selector.completedReceives(); List<NetworkReceive> completedReceives = selector.completedReceives();
for (NetworkReceive rcv : completedReceives) { for (NetworkReceive rcv : completedReceives) {
KafkaChannel channel = channel(rcv.source()); KafkaChannel channel = channel(rcv.source());
String channelId = channel.id(); if (!maybeBeginServerReauthentication(channel, rcv, time)) {
selector.mute(channelId); String channelId = channel.id();
NetworkSend send = new NetworkSend(rcv.source(), rcv.payload()); selector.mute(channelId);
if (outputChannel == null) NetworkSend send = new NetworkSend(rcv.source(), rcv.payload());
selector.send(send); if (outputChannel == null)
else { selector.send(send);
for (ByteBuffer buffer : send.buffers) else {
outputChannel.write(buffer); for (ByteBuffer buffer : send.buffers)
selector.unmute(channelId); outputChannel.write(buffer);
selector.unmute(channelId);
}
} }
} }
for (Send send : selector.completedSends()) { for (Send send : selector.completedSends()) {
@ -195,6 +244,17 @@ public class NioEchoServer extends Thread {
return numSent; return numSent;
} }
private static boolean maybeBeginServerReauthentication(KafkaChannel channel, NetworkReceive networkReceive, Time time) {
try {
if (TestUtils.apiKeyFrom(networkReceive) == ApiKeys.SASL_HANDSHAKE) {
return channel.maybeBeginServerReauthentication(networkReceive, () -> time.nanoseconds());
}
} catch (Exception e) {
// ignore
}
return false;
}
private String id(SocketChannel channel) { private String id(SocketChannel channel) {
return channel.socket().getLocalAddress().getHostAddress() + ":" + channel.socket().getLocalPort() + "-" + return channel.socket().getLocalAddress().getHostAddress() + ":" + channel.socket().getLocalPort() + "-" +
channel.socket().getInetAddress().getHostAddress() + ":" + channel.socket().getPort(); channel.socket().getInetAddress().getHostAddress() + ":" + channel.socket().getPort();

View File

@ -22,6 +22,7 @@ import org.apache.kafka.common.security.auth.SecurityProtocol;
import org.apache.kafka.common.security.JaasContext; import org.apache.kafka.common.security.JaasContext;
import org.apache.kafka.common.security.authenticator.TestJaasConfig; import org.apache.kafka.common.security.authenticator.TestJaasConfig;
import org.apache.kafka.common.security.plain.PlainLoginModule; import org.apache.kafka.common.security.plain.PlainLoginModule;
import org.apache.kafka.common.utils.Time;
import org.junit.Test; import org.junit.Test;
import java.util.Collections; import java.util.Collections;
@ -74,7 +75,7 @@ public class SaslChannelBuilderTest {
JaasContext jaasContext = new JaasContext("jaasContext", JaasContext.Type.SERVER, jaasConfig, null); JaasContext jaasContext = new JaasContext("jaasContext", JaasContext.Type.SERVER, jaasConfig, null);
Map<String, JaasContext> jaasContexts = Collections.singletonMap("PLAIN", jaasContext); Map<String, JaasContext> jaasContexts = Collections.singletonMap("PLAIN", jaasContext);
return new SaslChannelBuilder(Mode.CLIENT, jaasContexts, securityProtocol, new ListenerName("PLAIN"), return new SaslChannelBuilder(Mode.CLIENT, jaasContexts, securityProtocol, new ListenerName("PLAIN"),
false, "PLAIN", true, null, null); false, "PLAIN", true, null, null, Time.SYSTEM);
} }
} }

View File

@ -893,7 +893,7 @@ public class SslTransportLayerTest {
TestSecurityConfig config = new TestSecurityConfig(sslServerConfigs); TestSecurityConfig config = new TestSecurityConfig(sslServerConfigs);
ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol);
ChannelBuilder serverChannelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, ChannelBuilder serverChannelBuilder = ChannelBuilders.serverChannelBuilder(listenerName,
false, securityProtocol, config, null, null); false, securityProtocol, config, null, null, time);
server = new NioEchoServer(listenerName, securityProtocol, config, server = new NioEchoServer(listenerName, securityProtocol, config,
"localhost", serverChannelBuilder, null, time); "localhost", serverChannelBuilder, null, time);
server.start(); server.start();
@ -953,7 +953,7 @@ public class SslTransportLayerTest {
TestSecurityConfig config = new TestSecurityConfig(sslServerConfigs); TestSecurityConfig config = new TestSecurityConfig(sslServerConfigs);
ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol); ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol);
ChannelBuilder serverChannelBuilder = ChannelBuilders.serverChannelBuilder(listenerName, ChannelBuilder serverChannelBuilder = ChannelBuilders.serverChannelBuilder(listenerName,
false, securityProtocol, config, null, null); false, securityProtocol, config, null, null, time);
server = new NioEchoServer(listenerName, securityProtocol, config, server = new NioEchoServer(listenerName, securityProtocol, config,
"localhost", serverChannelBuilder, null, time); "localhost", serverChannelBuilder, null, time);
server.start(); server.start();

View File

@ -49,8 +49,8 @@ public class ApiKeysTest {
* <ul> * <ul>
* <li> Cluster actions used only for inter-broker are throttled only if unauthorized * <li> Cluster actions used only for inter-broker are throttled only if unauthorized
* <li> SASL_HANDSHAKE and SASL_AUTHENTICATE are not throttled when used for authentication * <li> SASL_HANDSHAKE and SASL_AUTHENTICATE are not throttled when used for authentication
* when a connection is established. At any other time, this request returns an error * when a connection is established or for re-authentication thereafter; these requests
* response that may be throttled. * return an error response that may be throttled if they are sent otherwise.
* </ul> * </ul>
*/ */
@Test @Test

View File

@ -167,6 +167,10 @@ public class RequestResponseTest {
checkRequest(createSaslHandshakeRequest()); checkRequest(createSaslHandshakeRequest());
checkErrorResponse(createSaslHandshakeRequest(), new UnknownServerException()); checkErrorResponse(createSaslHandshakeRequest(), new UnknownServerException());
checkResponse(createSaslHandshakeResponse(), 0); checkResponse(createSaslHandshakeResponse(), 0);
checkRequest(createSaslAuthenticateRequest());
checkErrorResponse(createSaslAuthenticateRequest(), new UnknownServerException());
checkResponse(createSaslAuthenticateResponse(), 0);
checkResponse(createSaslAuthenticateResponse(), 1);
checkRequest(createApiVersionRequest()); checkRequest(createApiVersionRequest());
checkErrorResponse(createApiVersionRequest(), new UnknownServerException()); checkErrorResponse(createApiVersionRequest(), new UnknownServerException());
checkResponse(createApiVersionResponse(), 0); checkResponse(createApiVersionResponse(), 0);
@ -345,9 +349,19 @@ public class RequestResponseTest {
private void checkRequest(AbstractRequest req) throws Exception { private void checkRequest(AbstractRequest req) throws Exception {
// Check that we can serialize, deserialize and serialize again // Check that we can serialize, deserialize and serialize again
// We don't check for equality or hashCode because it is likely to fail for any request containing a HashMap // We don't check for equality or hashCode because it is likely to fail for any request containing a HashMap
checkRequest(req, false);
}
private void checkRequest(AbstractRequest req, boolean checkEqualityAndHashCode) throws Exception {
// Check that we can serialize, deserialize and serialize again
// Check for equality and hashCode only if indicated
Struct struct = req.toStruct(); Struct struct = req.toStruct();
AbstractRequest deserialized = (AbstractRequest) deserialize(req, struct, req.version()); AbstractRequest deserialized = (AbstractRequest) deserialize(req, struct, req.version());
deserialized.toStruct(); Struct struct2 = deserialized.toStruct();
if (checkEqualityAndHashCode) {
assertEquals(struct, struct2);
assertEquals(struct.hashCode(), struct2.hashCode());
}
} }
private void checkResponse(AbstractResponse response, int version) throws Exception { private void checkResponse(AbstractResponse response, int version) throws Exception {
@ -355,7 +369,7 @@ public class RequestResponseTest {
// We don't check for equality or hashCode because it is likely to fail for any response containing a HashMap // We don't check for equality or hashCode because it is likely to fail for any response containing a HashMap
Struct struct = response.toStruct((short) version); Struct struct = response.toStruct((short) version);
AbstractResponse deserialized = (AbstractResponse) deserialize(response, struct, (short) version); AbstractResponse deserialized = (AbstractResponse) deserialize(response, struct, (short) version);
deserialized.toStruct((short) version); Struct struct2 = deserialized.toStruct((short) version);
} }
private AbstractRequestResponse deserialize(AbstractRequestResponse req, Struct struct, short version) throws NoSuchMethodException, IllegalAccessException, InvocationTargetException { private AbstractRequestResponse deserialize(AbstractRequestResponse req, Struct struct, short version) throws NoSuchMethodException, IllegalAccessException, InvocationTargetException {
@ -975,6 +989,14 @@ public class RequestResponseTest {
return new SaslHandshakeResponse(Errors.NONE, singletonList("GSSAPI")); return new SaslHandshakeResponse(Errors.NONE, singletonList("GSSAPI"));
} }
private SaslAuthenticateRequest createSaslAuthenticateRequest() {
return new SaslAuthenticateRequest(ByteBuffer.wrap(new byte[0]));
}
private SaslAuthenticateResponse createSaslAuthenticateResponse() {
return new SaslAuthenticateResponse(Errors.NONE, null, ByteBuffer.wrap(new byte[0]), Long.MAX_VALUE);
}
private ApiVersionsRequest createApiVersionRequest() { private ApiVersionsRequest createApiVersionRequest() {
return new ApiVersionsRequest.Builder().build(); return new ApiVersionsRequest.Builder().build();
} }

View File

@ -36,6 +36,8 @@ public class TestSecurityConfig extends AbstractConfig {
Importance.MEDIUM, BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS_DOC) Importance.MEDIUM, BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS_DOC)
.define(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, Type.CLASS, .define(BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG, Type.CLASS,
null, Importance.MEDIUM, BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_DOC) null, Importance.MEDIUM, BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_DOC)
.define(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS, Type.LONG, 0L, Importance.MEDIUM,
BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS_DOC)
.withClientSslSupport() .withClientSslSupport()
.withClientSaslSupport(); .withClientSaslSupport();

View File

@ -184,7 +184,7 @@ public class SaslAuthenticatorFailureDelayTest {
String saslMechanism = (String) saslClientConfigs.get(SaslConfigs.SASL_MECHANISM); String saslMechanism = (String) saslClientConfigs.get(SaslConfigs.SASL_MECHANISM);
this.channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol, JaasContext.Type.CLIENT, this.channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol, JaasContext.Type.CLIENT,
new TestSecurityConfig(clientConfigs), null, saslMechanism, true); new TestSecurityConfig(clientConfigs), null, saslMechanism, time, true);
this.selector = NetworkTestUtils.createSelector(channelBuilder, time); this.selector = NetworkTestUtils.createSelector(channelBuilder, time);
} }

View File

@ -20,16 +20,20 @@ import java.io.IOException;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey; import java.nio.channels.SelectionKey;
import java.nio.charset.StandardCharsets;
import java.security.NoSuchAlgorithmException; import java.security.NoSuchAlgorithmException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Random; import java.util.Random;
import java.util.Base64.Encoder;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import javax.security.auth.Subject; import javax.security.auth.Subject;
import javax.security.auth.callback.Callback; import javax.security.auth.callback.Callback;
@ -80,6 +84,12 @@ import org.apache.kafka.common.security.JaasContext;
import org.apache.kafka.common.security.TestSecurityConfig; import org.apache.kafka.common.security.TestSecurityConfig;
import org.apache.kafka.common.security.auth.KafkaPrincipal; import org.apache.kafka.common.security.auth.KafkaPrincipal;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerTokenCallback;
import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerConfigException;
import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerIllegalTokenException;
import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredJws;
import org.apache.kafka.common.security.oauthbearer.internals.unsecured.OAuthBearerUnsecuredLoginCallbackHandler;
import org.apache.kafka.common.security.plain.PlainLoginModule; import org.apache.kafka.common.security.plain.PlainLoginModule;
import org.apache.kafka.common.security.scram.ScramCredential; import org.apache.kafka.common.security.scram.ScramCredential;
import org.apache.kafka.common.security.scram.internals.ScramCredentialUtils; import org.apache.kafka.common.security.scram.internals.ScramCredentialUtils;
@ -87,6 +97,7 @@ import org.apache.kafka.common.security.scram.internals.ScramFormatter;
import org.apache.kafka.common.security.scram.ScramLoginModule; import org.apache.kafka.common.security.scram.ScramLoginModule;
import org.apache.kafka.common.security.scram.internals.ScramMechanism; import org.apache.kafka.common.security.scram.internals.ScramMechanism;
import org.apache.kafka.common.security.token.delegation.TokenInformation; import org.apache.kafka.common.security.token.delegation.TokenInformation;
import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache;
import org.apache.kafka.common.utils.SecurityUtils; import org.apache.kafka.common.utils.SecurityUtils;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.authenticator.TestDigestLoginModule.DigestServerCallbackHandler; import org.apache.kafka.common.security.authenticator.TestDigestLoginModule.DigestServerCallbackHandler;
@ -108,6 +119,7 @@ import static org.junit.Assert.fail;
*/ */
public class SaslAuthenticatorTest { public class SaslAuthenticatorTest {
private static final long CONNECTIONS_MAX_REAUTH_MS_VALUE = 100L;
private static final int BUFFER_SIZE = 4 * 1024; private static final int BUFFER_SIZE = 4 * 1024;
private static Time time = Time.SYSTEM; private static Time time = Time.SYSTEM;
@ -142,6 +154,7 @@ public class SaslAuthenticatorTest {
/** /**
* Tests good path SASL/PLAIN client and server channels using SSL transport layer. * Tests good path SASL/PLAIN client and server channels using SSL transport layer.
* Also tests successful re-authentication.
*/ */
@Test @Test
public void testValidSaslPlainOverSsl() throws Exception { public void testValidSaslPlainOverSsl() throws Exception {
@ -150,12 +163,12 @@ public class SaslAuthenticatorTest {
configureMechanisms("PLAIN", Arrays.asList("PLAIN")); configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
server = createEchoServer(securityProtocol); server = createEchoServer(securityProtocol);
createAndCheckClientConnection(securityProtocol, node); checkAuthenticationAndReauthentication(securityProtocol, node);
server.verifyAuthenticationMetrics(1, 0);
} }
/** /**
* Tests good path SASL/PLAIN client and server channels using PLAINTEXT transport layer. * Tests good path SASL/PLAIN client and server channels using PLAINTEXT transport layer.
* Also tests successful re-authentication.
*/ */
@Test @Test
public void testValidSaslPlainOverPlaintext() throws Exception { public void testValidSaslPlainOverPlaintext() throws Exception {
@ -164,8 +177,7 @@ public class SaslAuthenticatorTest {
configureMechanisms("PLAIN", Arrays.asList("PLAIN")); configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
server = createEchoServer(securityProtocol); server = createEchoServer(securityProtocol);
createAndCheckClientConnection(securityProtocol, node); checkAuthenticationAndReauthentication(securityProtocol, node);
server.verifyAuthenticationMetrics(1, 0);
} }
/** /**
@ -182,6 +194,7 @@ public class SaslAuthenticatorTest {
createAndCheckClientAuthenticationFailure(securityProtocol, node, "PLAIN", createAndCheckClientAuthenticationFailure(securityProtocol, node, "PLAIN",
"Authentication failed: Invalid username or password"); "Authentication failed: Invalid username or password");
server.verifyAuthenticationMetrics(0, 1); server.verifyAuthenticationMetrics(0, 1);
server.verifyReauthenticationMetrics(0, 0);
} }
/** /**
@ -198,6 +211,7 @@ public class SaslAuthenticatorTest {
createAndCheckClientAuthenticationFailure(securityProtocol, node, "PLAIN", createAndCheckClientAuthenticationFailure(securityProtocol, node, "PLAIN",
"Authentication failed: Invalid username or password"); "Authentication failed: Invalid username or password");
server.verifyAuthenticationMetrics(0, 1); server.verifyAuthenticationMetrics(0, 1);
server.verifyReauthenticationMetrics(0, 0);
} }
/** /**
@ -263,6 +277,7 @@ public class SaslAuthenticatorTest {
/** /**
* Tests that servers supporting multiple SASL mechanisms work with clients using * Tests that servers supporting multiple SASL mechanisms work with clients using
* any of the enabled mechanisms. * any of the enabled mechanisms.
* Also tests successful re-authentication over multiple mechanisms.
*/ */
@Test @Test
public void testMultipleServerMechanisms() throws Exception { public void testMultipleServerMechanisms() throws Exception {
@ -275,23 +290,53 @@ public class SaslAuthenticatorTest {
String node1 = "1"; String node1 = "1";
saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "PLAIN"); saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "PLAIN");
createAndCheckClientConnection(securityProtocol, node1); createAndCheckClientConnection(securityProtocol, node1);
server.verifyAuthenticationMetrics(1, 0);
String node2 = "2"; Selector selector2 = null;
saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "DIGEST-MD5"); Selector selector3 = null;
createSelector(securityProtocol, saslClientConfigs); try {
InetSocketAddress addr = new InetSocketAddress("127.0.0.1", server.port()); String node2 = "2";
selector.connect(node2, addr, BUFFER_SIZE, BUFFER_SIZE); saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "DIGEST-MD5");
NetworkTestUtils.checkClientConnection(selector, node2, 100, 10); createSelector(securityProtocol, saslClientConfigs);
selector2 = selector;
InetSocketAddress addr = new InetSocketAddress("127.0.0.1", server.port());
selector.connect(node2, addr, BUFFER_SIZE, BUFFER_SIZE);
NetworkTestUtils.checkClientConnection(selector, node2, 100, 10);
selector = null; // keeps it from being closed when next one is created
server.verifyAuthenticationMetrics(2, 0);
String node3 = "3"; String node3 = "3";
saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "SCRAM-SHA-256"); saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "SCRAM-SHA-256");
createSelector(securityProtocol, saslClientConfigs); createSelector(securityProtocol, saslClientConfigs);
selector.connect(node3, new InetSocketAddress("127.0.0.1", server.port()), BUFFER_SIZE, BUFFER_SIZE); selector3 = selector;
NetworkTestUtils.checkClientConnection(selector, node3, 100, 10); selector.connect(node3, new InetSocketAddress("127.0.0.1", server.port()), BUFFER_SIZE, BUFFER_SIZE);
NetworkTestUtils.checkClientConnection(selector, node3, 100, 10);
server.verifyAuthenticationMetrics(3, 0);
/*
* Now re-authenticate the connections. First we have to sleep long enough so
* that the next write will cause re-authentication, which we expect to succeed.
*/
delay((long) (CONNECTIONS_MAX_REAUTH_MS_VALUE * 1.1));
server.verifyReauthenticationMetrics(0, 0);
NetworkTestUtils.checkClientConnection(selector2, node2, 100, 10);
server.verifyReauthenticationMetrics(1, 0);
NetworkTestUtils.checkClientConnection(selector3, node3, 100, 10);
server.verifyReauthenticationMetrics(2, 0);
} finally {
if (selector2 != null)
selector2.close();
if (selector3 != null)
selector3.close();
}
} }
/** /**
* Tests good path SASL/SCRAM-SHA-256 client and server channels. * Tests good path SASL/SCRAM-SHA-256 client and server channels.
* Also tests successful re-authentication.
*/ */
@Test @Test
public void testValidSaslScramSha256() throws Exception { public void testValidSaslScramSha256() throws Exception {
@ -300,8 +345,7 @@ public class SaslAuthenticatorTest {
server = createEchoServer(securityProtocol); server = createEchoServer(securityProtocol);
updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD); updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD);
createAndCheckClientConnection(securityProtocol, "0"); checkAuthenticationAndReauthentication(securityProtocol, "0");
server.verifyAuthenticationMetrics(1, 0);
} }
/** /**
@ -338,6 +382,7 @@ public class SaslAuthenticatorTest {
updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD); updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD);
createAndCheckClientAuthenticationFailure(securityProtocol, node, "SCRAM-SHA-256", null); createAndCheckClientAuthenticationFailure(securityProtocol, node, "SCRAM-SHA-256", null);
server.verifyAuthenticationMetrics(0, 1); server.verifyAuthenticationMetrics(0, 1);
server.verifyReauthenticationMetrics(0, 0);
} }
/** /**
@ -357,6 +402,7 @@ public class SaslAuthenticatorTest {
updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD); updateScramCredentialCache(TestJaasConfig.USERNAME, TestJaasConfig.PASSWORD);
createAndCheckClientAuthenticationFailure(securityProtocol, node, "SCRAM-SHA-256", null); createAndCheckClientAuthenticationFailure(securityProtocol, node, "SCRAM-SHA-256", null);
server.verifyAuthenticationMetrics(0, 1); server.verifyAuthenticationMetrics(0, 1);
server.verifyReauthenticationMetrics(0, 0);
} }
/** /**
@ -379,6 +425,7 @@ public class SaslAuthenticatorTest {
saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "SCRAM-SHA-512"); saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, "SCRAM-SHA-512");
createAndCheckClientConnection(securityProtocol, "2"); createAndCheckClientConnection(securityProtocol, "2");
server.verifyAuthenticationMetrics(1, 1); server.verifyAuthenticationMetrics(1, 1);
server.verifyReauthenticationMetrics(0, 0);
} }
/** /**
@ -420,6 +467,7 @@ public class SaslAuthenticatorTest {
//Check invalid tokenId/tokenInfo in tokenCache //Check invalid tokenId/tokenInfo in tokenCache
createAndCheckClientConnectionFailure(securityProtocol, "0"); createAndCheckClientConnectionFailure(securityProtocol, "0");
server.verifyAuthenticationMetrics(0, 1);
//Check valid token Info and invalid credentials //Check valid token Info and invalid credentials
KafkaPrincipal owner = SecurityUtils.parseKafkaPrincipal("User:Owner"); KafkaPrincipal owner = SecurityUtils.parseKafkaPrincipal("User:Owner");
@ -428,10 +476,74 @@ public class SaslAuthenticatorTest {
System.currentTimeMillis(), System.currentTimeMillis(), System.currentTimeMillis()); System.currentTimeMillis(), System.currentTimeMillis(), System.currentTimeMillis());
server.tokenCache().addToken(tokenId, tokenInfo); server.tokenCache().addToken(tokenId, tokenInfo);
createAndCheckClientConnectionFailure(securityProtocol, "0"); createAndCheckClientConnectionFailure(securityProtocol, "0");
server.verifyAuthenticationMetrics(0, 2);
//Check with valid token Info and credentials //Check with valid token Info and credentials
updateTokenCredentialCache(tokenId, tokenHmac); updateTokenCredentialCache(tokenId, tokenHmac);
createAndCheckClientConnection(securityProtocol, "0"); createAndCheckClientConnection(securityProtocol, "0");
server.verifyAuthenticationMetrics(1, 2);
server.verifyReauthenticationMetrics(0, 0);
}
@Test
public void testTokenReauthenticationOverSaslScram() throws Exception {
SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
TestJaasConfig jaasConfig = configureMechanisms("SCRAM-SHA-256", Arrays.asList("SCRAM-SHA-256"));
// create jaas config for token auth
Map<String, Object> options = new HashMap<>();
String tokenId = "token1";
String tokenHmac = "abcdefghijkl";
options.put("username", tokenId); // tokenId
options.put("password", tokenHmac); // token hmac
options.put(ScramLoginModule.TOKEN_AUTH_CONFIG, "true"); // enable token authentication
jaasConfig.createOrUpdateEntry(TestJaasConfig.LOGIN_CONTEXT_CLIENT, ScramLoginModule.class.getName(), options);
// ensure re-authentication based on token expiry rather than a default value
saslServerConfigs.put(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS, Long.MAX_VALUE);
/*
* create a token cache that adjusts the token expiration dynamically so that
* the first time the expiry is read during authentication we use it to define a
* session expiration time that we can then sleep through; then the second time
* the value is read (during re-authentication) it will be in the future.
*/
Function<Integer, Long> tokenLifetime = callNum -> 10 * callNum * CONNECTIONS_MAX_REAUTH_MS_VALUE;
DelegationTokenCache tokenCache = new DelegationTokenCache(ScramMechanism.mechanismNames()) {
int callNum = 0;
@Override
public TokenInformation token(String tokenId) {
TokenInformation baseTokenInfo = super.token(tokenId);
long thisLifetimeMs = System.currentTimeMillis() + tokenLifetime.apply(++callNum).longValue();
TokenInformation retvalTokenInfo = new TokenInformation(baseTokenInfo.tokenId(), baseTokenInfo.owner(),
baseTokenInfo.renewers(), baseTokenInfo.issueTimestamp(), thisLifetimeMs, thisLifetimeMs);
return retvalTokenInfo;
}
};
server = createEchoServer(ListenerName.forSecurityProtocol(securityProtocol), securityProtocol, tokenCache);
KafkaPrincipal owner = SecurityUtils.parseKafkaPrincipal("User:Owner");
KafkaPrincipal renewer = SecurityUtils.parseKafkaPrincipal("User:Renewer1");
TokenInformation tokenInfo = new TokenInformation(tokenId, owner, Collections.singleton(renewer),
System.currentTimeMillis(), System.currentTimeMillis(), System.currentTimeMillis());
server.tokenCache().addToken(tokenId, tokenInfo);
updateTokenCredentialCache(tokenId, tokenHmac);
// initial authentication must succeed
createClientConnection(securityProtocol, "0");
checkClientConnection("0");
// ensure metrics are as expected before trying to re-authenticate
server.verifyAuthenticationMetrics(1, 0);
server.verifyReauthenticationMetrics(0, 0);
/*
* Now re-authenticate and ensure it succeeds. We have to sleep long enough so
* that the current delegation token will be expired when the next write occurs;
* this will trigger a re-authentication. Then the second time the delegation
* token is read and transmitted to the server it will again have an expiration
* date in the future.
*/
delay(tokenLifetime.apply(1));
checkClientConnection("0");
server.verifyReauthenticationMetrics(1, 0);
} }
/** /**
@ -916,6 +1028,7 @@ public class SaslAuthenticatorTest {
server = createEchoServer(securityProtocol); server = createEchoServer(securityProtocol);
createAndCheckClientConnectionFailure(securityProtocol, node); createAndCheckClientConnectionFailure(securityProtocol, node);
server.verifyAuthenticationMetrics(0, 1); server.verifyAuthenticationMetrics(0, 1);
server.verifyReauthenticationMetrics(0, 0);
} }
/** /**
@ -931,6 +1044,7 @@ public class SaslAuthenticatorTest {
server = createEchoServer(securityProtocol); server = createEchoServer(securityProtocol);
createAndCheckClientConnectionFailure(securityProtocol, node); createAndCheckClientConnectionFailure(securityProtocol, node);
server.verifyAuthenticationMetrics(0, 1); server.verifyAuthenticationMetrics(0, 1);
server.verifyReauthenticationMetrics(0, 0);
} }
/** /**
@ -1207,7 +1321,164 @@ public class SaslAuthenticatorTest {
server = createEchoServer(securityProtocol); server = createEchoServer(securityProtocol);
createAndCheckClientConnection(securityProtocol, node); createAndCheckClientConnection(securityProtocol, node);
} }
/**
* Re-authentication must fail if principal changes
*/
@Test
public void testCannotReauthenticateWithDifferentPrincipal() throws Exception {
String node = "0";
SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS,
AlternateLoginCallbackHandler.class.getName());
configureMechanisms(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
Arrays.asList(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM));
server = createEchoServer(securityProtocol);
// initial authentication must succeed
createClientConnection(securityProtocol, node);
checkClientConnection(node);
// ensure metrics are as expected before trying to re-authenticate
server.verifyAuthenticationMetrics(1, 0);
server.verifyReauthenticationMetrics(0, 0);
/*
* Now re-authenticate with a different principal and ensure it fails. We first
* have to sleep long enough for the background refresh thread to replace the
* original token with a new one.
*/
delay(1000L);
try {
checkClientConnection(node);
fail("Re-authentication with a different principal should have failed but did not");
} catch (AssertionError e) {
// ignore, expected
server.verifyReauthenticationMetrics(0, 1);
}
}
/**
* Re-authentication must fail if mechanism changes
*/
@Test
public void testCannotReauthenticateWithDifferentMechanism() throws Exception {
String node = "0";
SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
configureMechanisms("DIGEST-MD5", Arrays.asList("DIGEST-MD5", "PLAIN"));
configureDigestMd5ServerCallback(securityProtocol);
server = createEchoServer(securityProtocol);
String saslMechanism = (String) saslClientConfigs.get(SaslConfigs.SASL_MECHANISM);
Map<String, ?> configs = new TestSecurityConfig(saslClientConfigs).values();
this.channelBuilder = new AlternateSaslChannelBuilder(Mode.CLIENT,
Collections.singletonMap(saslMechanism, JaasContext.loadClientContext(configs)), securityProtocol, null,
false, saslMechanism, true, credentialCache, null, time);
this.channelBuilder.configure(configs);
// initial authentication must succeed
this.selector = NetworkTestUtils.createSelector(channelBuilder, time);
InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
checkClientConnection(node);
// ensure metrics are as expected before trying to re-authenticate
server.verifyAuthenticationMetrics(1, 0);
server.verifyReauthenticationMetrics(0, 0);
/*
* Now re-authenticate with a different mechanism and ensure it fails. We have
* to sleep long enough so that the next write will trigger a re-authentication.
*/
delay((long) (CONNECTIONS_MAX_REAUTH_MS_VALUE * 1.1));
try {
checkClientConnection(node);
fail("Re-authentication with a different mechanism should have failed but did not");
} catch (AssertionError e) {
// ignore, expected
server.verifyAuthenticationMetrics(1, 0);
server.verifyReauthenticationMetrics(0, 1);
}
}
/**
* Second re-authentication must fail if it is sooner than one second after the first
*/
@Test
public void testCannotReauthenticateAgainFasterThanOneSecond() throws Exception {
String node = "0";
SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
configureMechanisms(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
Arrays.asList(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM));
server = createEchoServer(securityProtocol);
try {
createClientConnection(securityProtocol, node);
checkClientConnection(node);
server.verifyAuthenticationMetrics(1, 0);
server.verifyReauthenticationMetrics(0, 0);
/*
* Now sleep long enough so that the next write will cause re-authentication,
* which we expect to succeed.
*/
delay((long) (CONNECTIONS_MAX_REAUTH_MS_VALUE * 1.1));
checkClientConnection(node);
server.verifyAuthenticationMetrics(1, 0);
server.verifyReauthenticationMetrics(1, 0);
/*
* Now sleep long enough so that the next write will cause re-authentication,
* but this time we expect re-authentication to not occur since it has been too
* soon. The checkClientConnection() call should return an error saying it
* expected the one byte-plus-node response but got the SaslHandshakeRequest
* instead
*/
delay((long) (CONNECTIONS_MAX_REAUTH_MS_VALUE * 1.1));
NetworkTestUtils.checkClientConnection(selector, node, 1, 1);
fail("Expected a failure when trying to re-authenticate to quickly, but that did not occur");
} catch (AssertionError e) {
String expectedResponseTextRegex = "\\w-" + node;
String receivedResponseTextRegex = ".*" + OAuthBearerLoginModule.OAUTHBEARER_MECHANISM;
assertTrue(
"Should have received the SaslHandshakeRequest bytes back since we re-authenticated too quickly, but instead we got our generated message echoed back, implying re-auth succeeded when it should not have",
e.getMessage().matches(
".*\\<\\[" + expectedResponseTextRegex + "]>.*\\<\\[" + receivedResponseTextRegex + "]>"));
server.verifyReauthenticationMetrics(1, 0); // unchanged
} finally {
selector.close();
selector = null;
}
}
/**
* Tests good path SASL/PLAIN client and server channels using SSL transport layer.
* Repeatedly tests successful re-authentication over several seconds.
*/
@Test
public void testRepeatedValidSaslPlainOverSsl() throws Exception {
String node = "0";
SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
configureMechanisms("PLAIN", Arrays.asList("PLAIN"));
/*
* Make sure 85% of this value is at least 1 second otherwise it is possible for
* the client to start re-authenticating but the server does not start due to
* the 1-second minimum. If this happens the SASL HANDSHAKE request that was
* injected to start re-authentication will be echoed back to the client instead
* of the data that the client explicitly sent, and then the client will not
* recognize that data and will throw an assertion error.
*/
saslServerConfigs.put(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS,
new Double(1.1 * 1000L / 0.85).longValue());
server = createEchoServer(securityProtocol);
createClientConnection(securityProtocol, node);
checkClientConnection(node);
server.verifyAuthenticationMetrics(1, 0);
server.verifyReauthenticationMetrics(0, 0);
double successfulReauthentications = 0;
int desiredNumReauthentications = 5;
long startMs = Time.SYSTEM.milliseconds();
long timeoutMs = startMs + 1000 * 15; // stop after 15 seconds
while (successfulReauthentications < desiredNumReauthentications
&& Time.SYSTEM.milliseconds() < timeoutMs) {
checkClientConnection(node);
successfulReauthentications = server.metricValue("successful-reauthentication-total");
}
server.verifyReauthenticationMetrics(desiredNumReauthentications, 0);
}
/** /**
* Tests OAUTHBEARER client channels without tokens for the server. * Tests OAUTHBEARER client channels without tokens for the server.
*/ */
@ -1313,15 +1584,16 @@ public class SaslAuthenticatorTest {
if (isScram) if (isScram)
ScramCredentialUtils.createCache(credentialCache, Arrays.asList(saslMechanism)); ScramCredentialUtils.createCache(credentialCache, Arrays.asList(saslMechanism));
SaslChannelBuilder serverChannelBuilder = new SaslChannelBuilder(Mode.SERVER, jaasContexts, SaslChannelBuilder serverChannelBuilder = new SaslChannelBuilder(Mode.SERVER, jaasContexts,
securityProtocol, listenerName, false, saslMechanism, true, credentialCache, null) { securityProtocol, listenerName, false, saslMechanism, true, credentialCache, null, time) {
@Override @Override
protected SaslServerAuthenticator buildServerAuthenticator(Map<String, ?> configs, protected SaslServerAuthenticator buildServerAuthenticator(Map<String, ?> configs,
Map<String, AuthenticateCallbackHandler> callbackHandlers, Map<String, AuthenticateCallbackHandler> callbackHandlers,
String id, String id,
TransportLayer transportLayer, TransportLayer transportLayer,
Map<String, Subject> subjects) throws IOException { Map<String, Subject> subjects,
return new SaslServerAuthenticator(configs, callbackHandlers, id, subjects, null, listenerName, securityProtocol, transportLayer) { Map<String, Long> connectionsMaxReauthMsByMechanism) {
return new SaslServerAuthenticator(configs, callbackHandlers, id, subjects, null, listenerName, securityProtocol, transportLayer, connectionsMaxReauthMsByMechanism, time) {
@Override @Override
protected ApiVersionsResponse apiVersionsResponse() { protected ApiVersionsResponse apiVersionsResponse() {
@ -1359,7 +1631,7 @@ public class SaslAuthenticatorTest {
final Map<String, JaasContext> jaasContexts = Collections.singletonMap(saslMechanism, jaasContext); final Map<String, JaasContext> jaasContexts = Collections.singletonMap(saslMechanism, jaasContext);
SaslChannelBuilder clientChannelBuilder = new SaslChannelBuilder(Mode.CLIENT, jaasContexts, SaslChannelBuilder clientChannelBuilder = new SaslChannelBuilder(Mode.CLIENT, jaasContexts,
securityProtocol, listenerName, false, saslMechanism, true, null, null) { securityProtocol, listenerName, false, saslMechanism, true, null, null, time) {
@Override @Override
protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs, protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs,
@ -1368,16 +1640,16 @@ public class SaslAuthenticatorTest {
String serverHost, String serverHost,
String servicePrincipal, String servicePrincipal,
TransportLayer transportLayer, TransportLayer transportLayer,
Subject subject) throws IOException { Subject subject) {
return new SaslClientAuthenticator(configs, callbackHandler, id, subject, return new SaslClientAuthenticator(configs, callbackHandler, id, subject,
servicePrincipal, serverHost, saslMechanism, true, transportLayer) { servicePrincipal, serverHost, saslMechanism, true, transportLayer, time) {
@Override @Override
protected SaslHandshakeRequest createSaslHandshakeRequest(short version) { protected SaslHandshakeRequest createSaslHandshakeRequest(short version) {
return new SaslHandshakeRequest.Builder(saslMechanism).build((short) 0); return new SaslHandshakeRequest.Builder(saslMechanism).build((short) 0);
} }
@Override @Override
protected void saslAuthenticateVersion(short version) { protected void saslAuthenticateVersion(ApiVersionsResponse apiVersionsResponse) {
// Don't set version so that headers are disabled // Don't set version so that headers are disabled
} }
}; };
@ -1467,6 +1739,7 @@ public class SaslAuthenticatorTest {
private TestJaasConfig configureMechanisms(String clientMechanism, List<String> serverMechanisms) { private TestJaasConfig configureMechanisms(String clientMechanism, List<String> serverMechanisms) {
saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, clientMechanism); saslClientConfigs.put(SaslConfigs.SASL_MECHANISM, clientMechanism);
saslServerConfigs.put(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, serverMechanisms); saslServerConfigs.put(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG, serverMechanisms);
saslServerConfigs.put(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS, CONNECTIONS_MAX_REAUTH_MS_VALUE);
if (serverMechanisms.contains("DIGEST-MD5")) { if (serverMechanisms.contains("DIGEST-MD5")) {
saslServerConfigs.put("digest-md5." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS, saslServerConfigs.put("digest-md5." + BrokerSecurityConfigs.SASL_SERVER_CALLBACK_HANDLER_CLASS,
TestDigestLoginModule.DigestServerCallbackHandler.class.getName()); TestDigestLoginModule.DigestServerCallbackHandler.class.getName());
@ -1488,7 +1761,7 @@ public class SaslAuthenticatorTest {
String saslMechanism = (String) saslClientConfigs.get(SaslConfigs.SASL_MECHANISM); String saslMechanism = (String) saslClientConfigs.get(SaslConfigs.SASL_MECHANISM);
this.channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol, JaasContext.Type.CLIENT, this.channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol, JaasContext.Type.CLIENT,
new TestSecurityConfig(clientConfigs), null, saslMechanism, true); new TestSecurityConfig(clientConfigs), null, saslMechanism, time, true);
this.selector = NetworkTestUtils.createSelector(channelBuilder, time); this.selector = NetworkTestUtils.createSelector(channelBuilder, time);
} }
@ -1501,17 +1774,39 @@ public class SaslAuthenticatorTest {
new TestSecurityConfig(saslServerConfigs), credentialCache, time); new TestSecurityConfig(saslServerConfigs), credentialCache, time);
} }
private NioEchoServer createEchoServer(ListenerName listenerName, SecurityProtocol securityProtocol,
DelegationTokenCache tokenCache) throws Exception {
return NetworkTestUtils.createEchoServer(listenerName, securityProtocol,
new TestSecurityConfig(saslServerConfigs), credentialCache, 100, time, tokenCache);
}
private void createClientConnection(SecurityProtocol securityProtocol, String node) throws Exception { private void createClientConnection(SecurityProtocol securityProtocol, String node) throws Exception {
createSelector(securityProtocol, saslClientConfigs); createSelector(securityProtocol, saslClientConfigs);
InetSocketAddress addr = new InetSocketAddress("localhost", server.port()); InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE); selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
} }
private void createAndCheckClientConnection(SecurityProtocol securityProtocol, String node) throws Exception { private void checkClientConnection(String node) throws Exception {
createClientConnection(securityProtocol, node);
NetworkTestUtils.checkClientConnection(selector, node, 100, 10); NetworkTestUtils.checkClientConnection(selector, node, 100, 10);
selector.close(); }
selector = null;
private void closeClientConnectionIfNecessary() throws Exception {
if (selector != null) {
selector.close();
selector = null;
}
}
/*
* Also closes the connection after creating/checking it
*/
private void createAndCheckClientConnection(SecurityProtocol securityProtocol, String node) throws Exception {
try {
createClientConnection(securityProtocol, node);
checkClientConnection(node);
} finally {
closeClientConnectionIfNecessary();
}
} }
private void createAndCheckClientAuthenticationFailure(SecurityProtocol securityProtocol, String node, private void createAndCheckClientAuthenticationFailure(SecurityProtocol securityProtocol, String node,
@ -1519,18 +1814,47 @@ public class SaslAuthenticatorTest {
ChannelState finalState = createAndCheckClientConnectionFailure(securityProtocol, node); ChannelState finalState = createAndCheckClientConnectionFailure(securityProtocol, node);
Exception exception = finalState.exception(); Exception exception = finalState.exception();
assertTrue("Invalid exception class " + exception.getClass(), exception instanceof SaslAuthenticationException); assertTrue("Invalid exception class " + exception.getClass(), exception instanceof SaslAuthenticationException);
if (expectedErrorMessage == null) if (expectedErrorMessage != null)
expectedErrorMessage = "Authentication failed due to invalid credentials with SASL mechanism " + mechanism; // check for full equality
assertEquals(expectedErrorMessage, exception.getMessage()); assertEquals(expectedErrorMessage, exception.getMessage());
else {
String expectedErrorMessagePrefix = "Authentication failed during authentication due to invalid credentials with SASL mechanism "
+ mechanism + ": ";
if (exception.getMessage().startsWith(expectedErrorMessagePrefix))
return;
// we didn't match a recognized error message, so fail
fail("Incorrect failure message: " + exception.getMessage());
}
} }
private ChannelState createAndCheckClientConnectionFailure(SecurityProtocol securityProtocol, String node) private ChannelState createAndCheckClientConnectionFailure(SecurityProtocol securityProtocol, String node)
throws Exception { throws Exception {
createClientConnection(securityProtocol, node); try {
ChannelState finalState = NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED); createClientConnection(securityProtocol, node);
selector.close(); ChannelState finalState = NetworkTestUtils.waitForChannelClose(selector, node, ChannelState.State.AUTHENTICATION_FAILED);
selector = null; return finalState;
return finalState; } finally {
closeClientConnectionIfNecessary();
}
}
private void checkAuthenticationAndReauthentication(SecurityProtocol securityProtocol, String node)
throws Exception, InterruptedException {
try {
createClientConnection(securityProtocol, node);
checkClientConnection(node);
server.verifyAuthenticationMetrics(1, 0);
/*
* Now re-authenticate the connection. First we have to sleep long enough so
* that the next write will cause re-authentication, which we expect to succeed.
*/
delay((long) (CONNECTIONS_MAX_REAUTH_MS_VALUE * 1.1));
server.verifyReauthenticationMetrics(0, 0);
checkClientConnection(node);
server.verifyReauthenticationMetrics(1, 0);
} finally {
closeClientConnectionIfNecessary();
}
} }
private AbstractResponse sendKafkaRequestReceiveResponse(String node, ApiKeys apiKey, AbstractRequest request) throws IOException { private AbstractResponse sendKafkaRequestReceiveResponse(String node, ApiKeys apiKey, AbstractRequest request) throws IOException {
@ -1604,6 +1928,7 @@ public class SaslAuthenticatorTest {
return new ApiVersionsRequest.Builder((short) 0).build(); return new ApiVersionsRequest.Builder((short) 0).build();
} }
@SuppressWarnings("unchecked")
private void updateTokenCredentialCache(String username, String password) throws NoSuchAlgorithmException { private void updateTokenCredentialCache(String username, String password) throws NoSuchAlgorithmException {
for (String mechanism : (List<String>) saslServerConfigs.get(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG)) { for (String mechanism : (List<String>) saslServerConfigs.get(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG)) {
ScramMechanism scramMechanism = ScramMechanism.forMechanismName(mechanism); ScramMechanism scramMechanism = ScramMechanism.forMechanismName(mechanism);
@ -1615,6 +1940,12 @@ public class SaslAuthenticatorTest {
} }
} }
private static void delay(long delayMillis) throws InterruptedException {
final long startTime = System.currentTimeMillis();
while ((System.currentTimeMillis() - startTime) < delayMillis)
Thread.sleep(CONNECTIONS_MAX_REAUTH_MS_VALUE / 5);
}
public static class TestClientCallbackHandler implements AuthenticateCallbackHandler { public static class TestClientCallbackHandler implements AuthenticateCallbackHandler {
static final String USERNAME = "TestClientCallbackHandler-user"; static final String USERNAME = "TestClientCallbackHandler-user";
@ -1731,4 +2062,114 @@ public class SaslAuthenticatorTest {
} }
} }
} }
/*
* Create an alternate login callback handler that continually returns a
* different principal
*/
public static class AlternateLoginCallbackHandler implements AuthenticateCallbackHandler {
private static final OAuthBearerUnsecuredLoginCallbackHandler DELEGATE = new OAuthBearerUnsecuredLoginCallbackHandler();
private static final String QUOTE = "\"";
private static int numInvocations = 0;
@Override
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
DELEGATE.handle(callbacks);
// now change any returned token to have a different principal name
if (callbacks.length > 0)
for (Callback callback : callbacks) {
if (callback instanceof OAuthBearerTokenCallback) {
OAuthBearerTokenCallback oauthBearerTokenCallback = (OAuthBearerTokenCallback) callback;
OAuthBearerToken token = oauthBearerTokenCallback.token();
if (token != null) {
String changedPrincipalNameToUse = token.principalName()
+ String.valueOf(++numInvocations);
String headerJson = "{" + claimOrHeaderJsonText("alg", "none") + "}";
/*
* Use a short lifetime so the background refresh thread replaces it before we
* re-authenticate
*/
String lifetimeSecondsValueToUse = "1";
String claimsJson;
try {
claimsJson = String.format("{%s,%s,%s}",
expClaimText(Long.parseLong(lifetimeSecondsValueToUse)),
claimOrHeaderJsonText("iat", time.milliseconds() / 1000.0),
claimOrHeaderJsonText("sub", changedPrincipalNameToUse));
} catch (NumberFormatException e) {
throw new OAuthBearerConfigException(e.getMessage());
}
try {
Encoder urlEncoderNoPadding = Base64.getUrlEncoder().withoutPadding();
OAuthBearerUnsecuredJws jws = new OAuthBearerUnsecuredJws(String.format("%s.%s.",
urlEncoderNoPadding.encodeToString(headerJson.getBytes(StandardCharsets.UTF_8)),
urlEncoderNoPadding
.encodeToString(claimsJson.getBytes(StandardCharsets.UTF_8))),
"sub", "scope");
oauthBearerTokenCallback.token(jws);
} catch (OAuthBearerIllegalTokenException e) {
// occurs if the principal claim doesn't exist or has an empty value
throw new OAuthBearerConfigException(e.getMessage(), e);
}
}
}
}
}
private static String claimOrHeaderJsonText(String claimName, String claimValue) {
return QUOTE + claimName + QUOTE + ":" + QUOTE + claimValue + QUOTE;
}
private static String claimOrHeaderJsonText(String claimName, Number claimValue) {
return QUOTE + claimName + QUOTE + ":" + claimValue;
}
private static String expClaimText(long lifetimeSeconds) {
return claimOrHeaderJsonText("exp", time.milliseconds() / 1000.0 + lifetimeSeconds);
}
@Override
public void configure(Map<String, ?> configs, String saslMechanism,
List<AppConfigurationEntry> jaasConfigEntries) {
DELEGATE.configure(configs, saslMechanism, jaasConfigEntries);
}
@Override
public void close() {
DELEGATE.close();
}
}
/*
* Define a channel builder that starts with the DIGEST-MD5 mechanism and then
* switches to the PLAIN mechanism
*/
private static class AlternateSaslChannelBuilder extends SaslChannelBuilder {
private int numInvocations = 0;
public AlternateSaslChannelBuilder(Mode mode, Map<String, JaasContext> jaasContexts,
SecurityProtocol securityProtocol, ListenerName listenerName, boolean isInterBrokerListener,
String clientSaslMechanism, boolean handshakeRequestEnable, CredentialCache credentialCache,
DelegationTokenCache tokenCache, Time time) {
super(mode, jaasContexts, securityProtocol, listenerName, isInterBrokerListener, clientSaslMechanism,
handshakeRequestEnable, credentialCache, tokenCache, time);
}
@Override
protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs,
AuthenticateCallbackHandler callbackHandler, String id, String serverHost, String servicePrincipal,
TransportLayer transportLayer, Subject subject) {
if (++numInvocations == 1)
return new SaslClientAuthenticator(configs, callbackHandler, id, subject, servicePrincipal, serverHost,
"DIGEST-MD5", true, transportLayer, time);
else
return new SaslClientAuthenticator(configs, callbackHandler, id, subject, servicePrincipal, serverHost,
"PLAIN", true, transportLayer, time) {
@Override
protected SaslHandshakeRequest createSaslHandshakeRequest(short version) {
return new SaslHandshakeRequest.Builder("PLAIN").build(version);
}
};
}
}
} }

View File

@ -28,6 +28,7 @@ import org.apache.kafka.common.protocol.types.Struct;
import org.apache.kafka.common.requests.RequestHeader; import org.apache.kafka.common.requests.RequestHeader;
import org.apache.kafka.common.security.JaasContext; import org.apache.kafka.common.security.JaasContext;
import org.apache.kafka.common.security.plain.PlainLoginModule; import org.apache.kafka.common.security.plain.PlainLoginModule;
import org.apache.kafka.common.utils.Time;
import org.junit.Test; import org.junit.Test;
import javax.security.auth.Subject; import javax.security.auth.Subject;
@ -100,7 +101,7 @@ public class SaslServerAuthenticatorTest {
Map<String, AuthenticateCallbackHandler> callbackHandlers = Collections.singletonMap( Map<String, AuthenticateCallbackHandler> callbackHandlers = Collections.singletonMap(
mechanism, new SaslServerCallbackHandler()); mechanism, new SaslServerCallbackHandler());
return new SaslServerAuthenticator(configs, callbackHandlers, "node", subjects, null, return new SaslServerAuthenticator(configs, callbackHandlers, "node", subjects, null,
new ListenerName("ssl"), SecurityProtocol.SASL_SSL, transportLayer); new ListenerName("ssl"), SecurityProtocol.SASL_SSL, transportLayer, Collections.emptyMap(), Time.SYSTEM);
} }
} }

View File

@ -17,6 +17,7 @@
package org.apache.kafka.common.security.oauthbearer.internals; package org.apache.kafka.common.security.oauthbearer.internals;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertNull; import static org.junit.Assert.assertNull;
@ -36,6 +37,7 @@ import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.security.JaasContext; import org.apache.kafka.common.security.JaasContext;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.auth.SaslExtensions; import org.apache.kafka.common.security.auth.SaslExtensions;
import org.apache.kafka.common.security.authenticator.SaslInternalConfigs;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback; import org.apache.kafka.common.security.oauthbearer.OAuthBearerExtensionsValidatorCallback;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
@ -103,6 +105,15 @@ public class OAuthBearerSaslServerTest {
assertTrue("Next challenge is not empty", nextChallenge.length == 0); assertTrue("Next challenge is not empty", nextChallenge.length == 0);
} }
@Test
public void negotiatedProperty() throws Exception {
saslServer.evaluateResponse(clientInitialResponse(USER));
OAuthBearerToken token = (OAuthBearerToken) saslServer.getNegotiatedProperty("OAUTHBEARER.token");
assertNotNull(token);
assertEquals(token.lifetimeMs(),
saslServer.getNegotiatedProperty(SaslInternalConfigs.CREDENTIAL_LIFETIME_MS_SASL_NEGOTIATED_PROPERTY_KEY));
}
/** /**
* SASL Extensions that are validated by the callback handler should be accessible through the {@code #getNegotiatedProperty()} method * SASL Extensions that are validated by the callback handler should be accessible through the {@code #getNegotiatedProperty()} method
*/ */

View File

@ -20,6 +20,7 @@ package org.apache.kafka.test;
* Interface to wrap actions that are required to wait until a condition is met * Interface to wrap actions that are required to wait until a condition is met
* for testing purposes. Note that this is not intended to do any assertions. * for testing purposes. Note that this is not intended to do any assertions.
*/ */
@FunctionalInterface
public interface TestCondition { public interface TestCondition {
boolean conditionMet(); boolean conditionMet();

View File

@ -21,7 +21,10 @@ import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.common.Cluster; import org.apache.kafka.common.Cluster;
import org.apache.kafka.common.Node; import org.apache.kafka.common.Node;
import org.apache.kafka.common.PartitionInfo; import org.apache.kafka.common.PartitionInfo;
import org.apache.kafka.common.network.NetworkReceive;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.types.Struct; import org.apache.kafka.common.protocol.types.Struct;
import org.apache.kafka.common.requests.RequestHeader;
import org.apache.kafka.common.utils.Utils; import org.apache.kafka.common.utils.Utils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -46,6 +49,7 @@ import java.util.UUID;
import java.util.regex.Matcher; import java.util.regex.Matcher;
import java.util.regex.Pattern; import java.util.regex.Pattern;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.function.Supplier;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import static java.util.Arrays.asList; import static java.util.Arrays.asList;
@ -253,7 +257,14 @@ public class TestUtils {
* uses default value of 15 seconds for timeout * uses default value of 15 seconds for timeout
*/ */
public static void waitForCondition(final TestCondition testCondition, final String conditionDetails) throws InterruptedException { public static void waitForCondition(final TestCondition testCondition, final String conditionDetails) throws InterruptedException {
waitForCondition(testCondition, DEFAULT_MAX_WAIT_MS, conditionDetails); waitForCondition(testCondition, DEFAULT_MAX_WAIT_MS, () -> conditionDetails);
}
/**
* uses default value of 15 seconds for timeout
*/
public static void waitForCondition(final TestCondition testCondition, final Supplier<String> conditionDetailsSupplier) throws InterruptedException {
waitForCondition(testCondition, DEFAULT_MAX_WAIT_MS, conditionDetailsSupplier);
} }
/** /**
@ -263,6 +274,16 @@ public class TestUtils {
* avoid transient failures due to slow or overloaded machines. * avoid transient failures due to slow or overloaded machines.
*/ */
public static void waitForCondition(final TestCondition testCondition, final long maxWaitMs, String conditionDetails) throws InterruptedException { public static void waitForCondition(final TestCondition testCondition, final long maxWaitMs, String conditionDetails) throws InterruptedException {
waitForCondition(testCondition, maxWaitMs, () -> conditionDetails);
}
/**
* Wait for condition to be met for at most {@code maxWaitMs} and throw assertion failure otherwise.
* This should be used instead of {@code Thread.sleep} whenever possible as it allows a longer timeout to be used
* without unnecessarily increasing test time (as the condition is checked frequently). The longer timeout is needed to
* avoid transient failures due to slow or overloaded machines.
*/
public static void waitForCondition(final TestCondition testCondition, final long maxWaitMs, Supplier<String> conditionDetailsSupplier) throws InterruptedException {
final long startTime = System.currentTimeMillis(); final long startTime = System.currentTimeMillis();
boolean testConditionMet; boolean testConditionMet;
@ -274,7 +295,8 @@ public class TestUtils {
// could be avoided by making the implementations more robust, but we have a large number of such implementations // could be avoided by making the implementations more robust, but we have a large number of such implementations
// and it's easier to simply avoid the issue altogether) // and it's easier to simply avoid the issue altogether)
if (!testConditionMet) { if (!testConditionMet) {
conditionDetails = conditionDetails != null ? conditionDetails : ""; String conditionDetailsSupplied = conditionDetailsSupplier != null ? conditionDetailsSupplier.get() : null;
String conditionDetails = conditionDetailsSupplied != null ? conditionDetailsSupplied : "";
throw new AssertionError("Condition not met within timeout " + maxWaitMs + ". " + conditionDetails); throw new AssertionError("Condition not met within timeout " + maxWaitMs + ". " + conditionDetails);
} }
} }
@ -356,4 +378,8 @@ public class TestUtils {
exceptionClass, cause.getClass()); exceptionClass, cause.getClass());
} }
} }
public static ApiKeys apiKeyFrom(NetworkReceive networkReceive) {
return RequestHeader.parse(networkReceive.payload().duplicate()).apiKey();
}
} }

View File

@ -101,7 +101,7 @@ public class WorkerGroupMember {
config.getString(CommonClientConfigs.CLIENT_DNS_LOOKUP_CONFIG)); config.getString(CommonClientConfigs.CLIENT_DNS_LOOKUP_CONFIG));
this.metadata.update(Cluster.bootstrap(addresses), Collections.<String>emptySet(), 0); this.metadata.update(Cluster.bootstrap(addresses), Collections.<String>emptySet(), 0);
String metricGrpPrefix = "connect"; String metricGrpPrefix = "connect";
ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(config); ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(config, time);
NetworkClient netClient = new NetworkClient( NetworkClient netClient = new NetworkClient(
new Selector(config.getLong(CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_CONFIG), metrics, time, metricGrpPrefix, channelBuilder, logContext), new Selector(config.getLong(CommonClientConfigs.CONNECTIONS_MAX_IDLE_MS_CONFIG), metrics, time, metricGrpPrefix, channelBuilder, logContext),
this.metadata, this.metadata,

View File

@ -432,7 +432,7 @@ object AdminClient {
val time = Time.SYSTEM val time = Time.SYSTEM
val metrics = new Metrics(time) val metrics = new Metrics(time)
val metadata = new Metadata(100L, 60 * 60 * 1000L, true) val metadata = new Metadata(100L, 60 * 60 * 1000L, true)
val channelBuilder = ClientUtils.createChannelBuilder(config) val channelBuilder = ClientUtils.createChannelBuilder(config, time)
val requestTimeoutMs = config.getInt(CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG) val requestTimeoutMs = config.getInt(CommonClientConfigs.REQUEST_TIMEOUT_MS_CONFIG)
val retryBackoffMs = config.getLong(CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG) val retryBackoffMs = config.getLong(CommonClientConfigs.RETRY_BACKOFF_MS_CONFIG)

View File

@ -116,6 +116,7 @@ class ControllerChannelManager(controllerContext: ControllerContext, config: Kaf
config, config,
config.interBrokerListenerName, config.interBrokerListenerName,
config.saslMechanismInterBrokerProtocol, config.saslMechanismInterBrokerProtocol,
time,
config.saslInterBrokerHandshakeRequestEnable config.saslInterBrokerHandshakeRequestEnable
) )
val selector = new Selector( val selector = new Selector(

View File

@ -51,6 +51,7 @@ object TransactionMarkerChannelManager {
config, config,
config.interBrokerListenerName, config.interBrokerListenerName,
config.saslMechanismInterBrokerProtocol, config.saslMechanismInterBrokerProtocol,
time,
config.saslInterBrokerHandshakeRequestEnable config.saslInterBrokerHandshakeRequestEnable
) )
val selector = new Selector( val selector = new Selector(

View File

@ -23,6 +23,7 @@ import java.nio.channels._
import java.nio.channels.{Selector => NSelector} import java.nio.channels.{Selector => NSelector}
import java.util.concurrent._ import java.util.concurrent._
import java.util.concurrent.atomic._ import java.util.concurrent.atomic._
import java.util.function.Supplier
import com.yammer.metrics.core.Gauge import com.yammer.metrics.core.Gauge
import kafka.cluster.{BrokerEndPoint, EndPoint} import kafka.cluster.{BrokerEndPoint, EndPoint}
@ -35,8 +36,10 @@ import org.apache.kafka.common.{KafkaException, Reconfigurable}
import org.apache.kafka.common.memory.{MemoryPool, SimpleMemoryPool} import org.apache.kafka.common.memory.{MemoryPool, SimpleMemoryPool}
import org.apache.kafka.common.metrics._ import org.apache.kafka.common.metrics._
import org.apache.kafka.common.metrics.stats.Meter import org.apache.kafka.common.metrics.stats.Meter
import org.apache.kafka.common.metrics.stats.Total
import org.apache.kafka.common.network.KafkaChannel.ChannelMuteEvent import org.apache.kafka.common.network.KafkaChannel.ChannelMuteEvent
import org.apache.kafka.common.network.{ChannelBuilder, ChannelBuilders, KafkaChannel, ListenerName, Selectable, Send, Selector => KSelector} import org.apache.kafka.common.network.{ChannelBuilder, ChannelBuilders, KafkaChannel, ListenerName, Selectable, Send, Selector => KSelector}
import org.apache.kafka.common.protocol.ApiKeys
import org.apache.kafka.common.requests.{RequestContext, RequestHeader} import org.apache.kafka.common.requests.{RequestContext, RequestHeader}
import org.apache.kafka.common.security.auth.SecurityProtocol import org.apache.kafka.common.security.auth.SecurityProtocol
import org.apache.kafka.common.utils.{KafkaThread, LogContext, Time} import org.apache.kafka.common.utils.{KafkaThread, LogContext, Time}
@ -117,6 +120,19 @@ class SocketServer(val config: KafkaConfig, val metrics: Metrics, val time: Time
def value = memoryPool.size() - memoryPool.availableMemory() def value = memoryPool.size() - memoryPool.availableMemory()
} }
) )
newGauge("ExpiredConnectionsKilledCount",
new Gauge[Double] {
def value = SocketServer.this.synchronized {
val expiredConnectionsKilledCountMetricNames = processors.values.asScala.map { p =>
metrics.metricName("expired-connections-killed-count", "socket-server-metrics", p.metricTags)
}
expiredConnectionsKilledCountMetricNames.map { metricName =>
Option(metrics.metric(metricName)).fold(0.0)(m => m.metricValue.asInstanceOf[Double])
}.sum
}
}
)
info("Started " + acceptors.size + " acceptor threads") info("Started " + acceptors.size + " acceptor threads")
} }
@ -548,6 +564,10 @@ private[kafka] class Processor(val id: Int,
// also includes the listener name) // also includes the listener name)
Map(NetworkProcessorMetricTag -> id.toString) Map(NetworkProcessorMetricTag -> id.toString)
) )
val expiredConnectionsKilledCount = new Total()
private val expiredConnectionsKilledCountMetricName = metrics.metricName("expired-connections-killed-count", "socket-server-metrics", metricTags)
metrics.addMetric(expiredConnectionsKilledCountMetricName, expiredConnectionsKilledCount)
private val selector = createSelector( private val selector = createSelector(
ChannelBuilders.serverChannelBuilder(listenerName, ChannelBuilders.serverChannelBuilder(listenerName,
@ -555,7 +575,8 @@ private[kafka] class Processor(val id: Int,
securityProtocol, securityProtocol,
config, config,
credentialProvider.credentialCache, credentialProvider.credentialCache,
credentialProvider.tokenCache)) credentialProvider.tokenCache,
time))
// Visible to override for testing // Visible to override for testing
protected[network] def createSelector(channelBuilder: ChannelBuilder): KSelector = { protected[network] def createSelector(channelBuilder: ChannelBuilder): KSelector = {
channelBuilder match { channelBuilder match {
@ -685,6 +706,10 @@ private[kafka] class Processor(val id: Int,
} }
} }
private def nowNanosSupplier = new Supplier[java.lang.Long] {
override def get(): java.lang.Long = time.nanoseconds()
}
private def poll() { private def poll() {
try selector.poll(300) try selector.poll(300)
catch { catch {
@ -701,14 +726,25 @@ private[kafka] class Processor(val id: Int,
openOrClosingChannel(receive.source) match { openOrClosingChannel(receive.source) match {
case Some(channel) => case Some(channel) =>
val header = RequestHeader.parse(receive.payload) val header = RequestHeader.parse(receive.payload)
val connectionId = receive.source if (header.apiKey() == ApiKeys.SASL_HANDSHAKE && channel.maybeBeginServerReauthentication(receive, nowNanosSupplier))
val context = new RequestContext(header, connectionId, channel.socketAddress, trace(s"Begin re-authentication: $channel")
channel.principal, listenerName, securityProtocol) else {
val req = new RequestChannel.Request(processor = id, context = context, val nowNanos = time.nanoseconds()
startTimeNanos = time.nanoseconds, memoryPool, receive.payload, requestChannel.metrics) if (channel.serverAuthenticationSessionExpired(nowNanos)) {
requestChannel.sendRequest(req) channel.disconnect()
selector.mute(connectionId) debug(s"Disconnected expired channel: $channel : $header")
handleChannelMuteEvent(connectionId, ChannelMuteEvent.REQUEST_RECEIVED) expiredConnectionsKilledCount.record(null, 1, 0)
} else {
val connectionId = receive.source
val context = new RequestContext(header, connectionId, channel.socketAddress,
channel.principal, listenerName, securityProtocol)
val req = new RequestChannel.Request(processor = id, context = context,
startTimeNanos = nowNanos, memoryPool, receive.payload, requestChannel.metrics)
requestChannel.sendRequest(req)
selector.mute(connectionId)
handleChannelMuteEvent(connectionId, ChannelMuteEvent.REQUEST_RECEIVED)
}
}
case None => case None =>
// This should never happen since completed receives are processed immediately after `poll()` // This should never happen since completed receives are processed immediately after `poll()`
throw new IllegalStateException(s"Channel ${receive.source} removed from selector before processing completed receive") throw new IllegalStateException(s"Channel ${receive.source} removed from selector before processing completed receive")
@ -883,6 +919,7 @@ private[kafka] class Processor(val id: Int,
override def shutdown(): Unit = { override def shutdown(): Unit = {
super.shutdown() super.shutdown()
removeMetric("IdlePercent", Map("networkProcessor" -> id.toString)) removeMetric("IdlePercent", Map("networkProcessor" -> id.toString))
metrics.removeMetric(expiredConnectionsKilledCountMetricName)
} }
} }

View File

@ -222,6 +222,9 @@ object Defaults {
val SslClientAuth = SslClientAuthNone val SslClientAuth = SslClientAuthNone
val SslPrincipalMappingRules = BrokerSecurityConfigs.DEFAULT_SSL_PRINCIPAL_MAPPING_RULES val SslPrincipalMappingRules = BrokerSecurityConfigs.DEFAULT_SSL_PRINCIPAL_MAPPING_RULES
/** ********* General Security configuration ***********/
val ConnectionsMaxReauthMsDefault = 0L
/** ********* Sasl configuration ***********/ /** ********* Sasl configuration ***********/
val SaslMechanismInterBrokerProtocol = SaslConfigs.DEFAULT_SASL_MECHANISM val SaslMechanismInterBrokerProtocol = SaslConfigs.DEFAULT_SASL_MECHANISM
val SaslEnabledMechanisms = SaslConfigs.DEFAULT_SASL_ENABLED_MECHANISMS val SaslEnabledMechanisms = SaslConfigs.DEFAULT_SASL_ENABLED_MECHANISMS
@ -422,6 +425,7 @@ object KafkaConfig {
/** ******** Common Security Configuration *************/ /** ******** Common Security Configuration *************/
val PrincipalBuilderClassProp = BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG val PrincipalBuilderClassProp = BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_CONFIG
val ConnectionsMaxReauthMsProp = BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS
/** ********* SSL Configuration ****************/ /** ********* SSL Configuration ****************/
val SslProtocolProp = SslConfigs.SSL_PROTOCOL_CONFIG val SslProtocolProp = SslConfigs.SSL_PROTOCOL_CONFIG
@ -744,6 +748,7 @@ object KafkaConfig {
/** ******** Common Security Configuration *************/ /** ******** Common Security Configuration *************/
val PrincipalBuilderClassDoc = BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_DOC val PrincipalBuilderClassDoc = BrokerSecurityConfigs.PRINCIPAL_BUILDER_CLASS_DOC
val ConnectionsMaxReauthMsDoc = BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS_DOC
/** ********* SSL Configuration ****************/ /** ********* SSL Configuration ****************/
val SslProtocolDoc = SslConfigs.SSL_PROTOCOL_DOC val SslProtocolDoc = SslConfigs.SSL_PROTOCOL_DOC
@ -983,6 +988,9 @@ object KafkaConfig {
.define(AlterLogDirsReplicationQuotaWindowSizeSecondsProp, INT, Defaults.AlterLogDirsReplicationQuotaWindowSizeSeconds, atLeast(1), LOW, AlterLogDirsReplicationQuotaWindowSizeSecondsDoc) .define(AlterLogDirsReplicationQuotaWindowSizeSecondsProp, INT, Defaults.AlterLogDirsReplicationQuotaWindowSizeSeconds, atLeast(1), LOW, AlterLogDirsReplicationQuotaWindowSizeSecondsDoc)
.define(ClientQuotaCallbackClassProp, CLASS, null, LOW, ClientQuotaCallbackClassDoc) .define(ClientQuotaCallbackClassProp, CLASS, null, LOW, ClientQuotaCallbackClassDoc)
/** ********* General Security Configuration ****************/
.define(ConnectionsMaxReauthMsProp, LONG, Defaults.ConnectionsMaxReauthMsDefault, MEDIUM, ConnectionsMaxReauthMsDoc)
/** ********* SSL Configuration ****************/ /** ********* SSL Configuration ****************/
.define(PrincipalBuilderClassProp, CLASS, null, MEDIUM, PrincipalBuilderClassDoc) .define(PrincipalBuilderClassProp, CLASS, null, MEDIUM, PrincipalBuilderClassDoc)
.define(SslProtocolProp, STRING, Defaults.SslProtocol, MEDIUM, SslProtocolDoc) .define(SslProtocolProp, STRING, Defaults.SslProtocol, MEDIUM, SslProtocolDoc)

View File

@ -421,6 +421,7 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP
config, config,
config.interBrokerListenerName, config.interBrokerListenerName,
config.saslMechanismInterBrokerProtocol, config.saslMechanismInterBrokerProtocol,
time,
config.saslInterBrokerHandshakeRequestEnable) config.saslInterBrokerHandshakeRequestEnable)
val selector = new Selector( val selector = new Selector(
NetworkReceive.UNLIMITED, NetworkReceive.UNLIMITED,

View File

@ -56,6 +56,7 @@ class ReplicaFetcherBlockingSend(sourceBroker: BrokerEndPoint,
brokerConfig, brokerConfig,
brokerConfig.interBrokerListenerName, brokerConfig.interBrokerListenerName,
brokerConfig.saslMechanismInterBrokerProtocol, brokerConfig.saslMechanismInterBrokerProtocol,
time,
brokerConfig.saslInterBrokerHandshakeRequestEnable brokerConfig.saslInterBrokerHandshakeRequestEnable
) )
val selector = new Selector( val selector = new Selector(

View File

@ -450,7 +450,7 @@ private class ReplicaFetcherBlockingSend(sourceNode: Node,
private val socketTimeout: Int = consumerConfig.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG) private val socketTimeout: Int = consumerConfig.getInt(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG)
private val networkClient = { private val networkClient = {
val channelBuilder = org.apache.kafka.clients.ClientUtils.createChannelBuilder(consumerConfig) val channelBuilder = org.apache.kafka.clients.ClientUtils.createChannelBuilder(consumerConfig, time)
val selector = new Selector( val selector = new Selector(
NetworkReceive.UNLIMITED, NetworkReceive.UNLIMITED,
consumerConfig.getLong(ConsumerConfig.CONNECTIONS_MAX_IDLE_MS_CONFIG), consumerConfig.getLong(ConsumerConfig.CONNECTIONS_MAX_IDLE_MS_CONFIG),

View File

@ -17,6 +17,9 @@
package kafka.api package kafka.api
import com.yammer.metrics.Metrics
import com.yammer.metrics.core.{Gauge, Metric, MetricName}
import java.io.File import java.io.File
import java.util.ArrayList import java.util.ArrayList
import java.util.concurrent.ExecutionException import java.util.concurrent.ExecutionException
@ -170,6 +173,7 @@ abstract class EndToEndAuthorizationTest extends IntegrationTestHarness with Sas
this.serverConfig.setProperty(KafkaConfig.OffsetsTopicReplicationFactorProp, "3") this.serverConfig.setProperty(KafkaConfig.OffsetsTopicReplicationFactorProp, "3")
this.serverConfig.setProperty(KafkaConfig.MinInSyncReplicasProp, "3") this.serverConfig.setProperty(KafkaConfig.MinInSyncReplicasProp, "3")
this.serverConfig.setProperty(KafkaConfig.DefaultReplicationFactorProp, "3") this.serverConfig.setProperty(KafkaConfig.DefaultReplicationFactorProp, "3")
this.serverConfig.setProperty(KafkaConfig.ConnectionsMaxReauthMsProp, "1500")
this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "group") this.consumerConfig.setProperty(ConsumerConfig.GROUP_ID_CONFIG, "group")
/** /**
@ -204,6 +208,27 @@ abstract class EndToEndAuthorizationTest extends IntegrationTestHarness with Sas
val consumer = createConsumer() val consumer = createConsumer()
consumer.assign(List(tp).asJava) consumer.assign(List(tp).asJava)
consumeRecords(consumer, numRecords) consumeRecords(consumer, numRecords)
confirmReauthenticationMetrics
}
protected def confirmReauthenticationMetrics() : Unit = {
val expiredConnectionsKilledCountTotal = getGauge("ExpiredConnectionsKilledCount").value()
servers.foreach { s =>
val numExpiredKilled = TestUtils.totalMetricValue(s, "expired-connections-killed-count")
assertTrue("Should have been zero expired connections killed: " + numExpiredKilled + "(total=" + expiredConnectionsKilledCountTotal + ")", numExpiredKilled == 0)
}
assertEquals("Should have been zero expired connections killed total", 0, expiredConnectionsKilledCountTotal, 0.0)
servers.foreach { s =>
assertTrue("failed re-authentications not 0", TestUtils.totalMetricValue(s, "failed-reauthentication-total") == 0)
}
}
private def getGauge(metricName: String) = {
Metrics.defaultRegistry.allMetrics.asScala
.filterKeys(k => k.getName == metricName)
.headOption
.getOrElse { fail( "Unable to find metric " + metricName ) }
._2.asInstanceOf[Gauge[Double]]
} }
@Test @Test
@ -212,6 +237,7 @@ abstract class EndToEndAuthorizationTest extends IntegrationTestHarness with Sas
val consumer = createConsumer() val consumer = createConsumer()
consumer.subscribe(List(topic).asJava) consumer.subscribe(List(topic).asJava)
consumeRecords(consumer, numRecords) consumeRecords(consumer, numRecords)
confirmReauthenticationMetrics
} }
@Test @Test
@ -222,6 +248,7 @@ abstract class EndToEndAuthorizationTest extends IntegrationTestHarness with Sas
val consumer = createConsumer() val consumer = createConsumer()
consumer.subscribe(List(topic).asJava) consumer.subscribe(List(topic).asJava)
consumeRecords(consumer, numRecords) consumeRecords(consumer, numRecords)
confirmReauthenticationMetrics
} }
@Test @Test
@ -232,6 +259,7 @@ abstract class EndToEndAuthorizationTest extends IntegrationTestHarness with Sas
val consumer = createConsumer() val consumer = createConsumer()
consumer.subscribe(List(topic).asJava) consumer.subscribe(List(topic).asJava)
consumeRecords(consumer, numRecords) consumeRecords(consumer, numRecords)
confirmReauthenticationMetrics
} }
@Test @Test
@ -242,6 +270,7 @@ abstract class EndToEndAuthorizationTest extends IntegrationTestHarness with Sas
val consumer = createConsumer() val consumer = createConsumer()
consumer.assign(List(tp2).asJava) consumer.assign(List(tp2).asJava)
consumeRecords(consumer, numRecords, topic = tp2.topic) consumeRecords(consumer, numRecords, topic = tp2.topic)
confirmReauthenticationMetrics
} }
private def setWildcardResourceAcls() { private def setWildcardResourceAcls() {
@ -280,6 +309,7 @@ abstract class EndToEndAuthorizationTest extends IntegrationTestHarness with Sas
def testNoProduceWithoutDescribeAcl(): Unit = { def testNoProduceWithoutDescribeAcl(): Unit = {
val producer = createProducer() val producer = createProducer()
sendRecords(producer, numRecords, tp) sendRecords(producer, numRecords, tp)
confirmReauthenticationMetrics
} }
@Test @Test
@ -296,6 +326,7 @@ abstract class EndToEndAuthorizationTest extends IntegrationTestHarness with Sas
case e: TopicAuthorizationException => case e: TopicAuthorizationException =>
assertEquals(Set(topic).asJava, e.unauthorizedTopics()) assertEquals(Set(topic).asJava, e.unauthorizedTopics())
} }
confirmReauthenticationMetrics
} }
/** /**
@ -309,6 +340,7 @@ abstract class EndToEndAuthorizationTest extends IntegrationTestHarness with Sas
consumer.assign(List(tp).asJava) consumer.assign(List(tp).asJava)
// the exception is expected when the consumer attempts to lookup offsets // the exception is expected when the consumer attempts to lookup offsets
consumeRecords(consumer) consumeRecords(consumer)
confirmReauthenticationMetrics
} }
@Test(expected = classOf[TopicAuthorizationException]) @Test(expected = classOf[TopicAuthorizationException])
@ -351,6 +383,7 @@ abstract class EndToEndAuthorizationTest extends IntegrationTestHarness with Sas
case e: TopicAuthorizationException => case e: TopicAuthorizationException =>
assertEquals(Set(topic).asJava, e.unauthorizedTopics()) assertEquals(Set(topic).asJava, e.unauthorizedTopics())
} }
confirmReauthenticationMetrics
} }
@Test @Test
@ -366,6 +399,7 @@ abstract class EndToEndAuthorizationTest extends IntegrationTestHarness with Sas
case e: TopicAuthorizationException => case e: TopicAuthorizationException =>
assertEquals(Set(topic).asJava, e.unauthorizedTopics()) assertEquals(Set(topic).asJava, e.unauthorizedTopics())
} }
confirmReauthenticationMetrics
} }
private def noConsumeWithDescribeAclSetup(): Unit = { private def noConsumeWithDescribeAclSetup(): Unit = {
@ -401,6 +435,7 @@ abstract class EndToEndAuthorizationTest extends IntegrationTestHarness with Sas
case e: GroupAuthorizationException => case e: GroupAuthorizationException =>
assertEquals(group, e.groupId()) assertEquals(group, e.groupId())
} }
confirmReauthenticationMetrics
} }
protected final def sendRecords(producer: KafkaProducer[Array[Byte], Array[Byte]], protected final def sendRecords(producer: KafkaProducer[Array[Byte], Array[Byte]],

View File

@ -74,5 +74,6 @@ abstract class SaslEndToEndAuthorizationTest extends EndToEndAuthorizationTest {
case e: TopicAuthorizationException => assertTrue(e.unauthorizedTopics.contains(topic)) case e: TopicAuthorizationException => assertTrue(e.unauthorizedTopics.contains(topic))
case e: GroupAuthorizationException => assertEquals(group, e.groupId) case e: GroupAuthorizationException => assertEquals(group, e.groupId)
} }
confirmReauthenticationMetrics
} }
} }

View File

@ -188,7 +188,7 @@ class GssapiAuthenticationTest extends IntegrationTestHarness with SaslSetup {
private def createSelector(): Selector = { private def createSelector(): Selector = {
val channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol, val channelBuilder = ChannelBuilders.clientChannelBuilder(securityProtocol,
JaasContext.Type.CLIENT, new TestSecurityConfig(clientConfig), null, kafkaClientSaslMechanism, true) JaasContext.Type.CLIENT, new TestSecurityConfig(clientConfig), null, kafkaClientSaslMechanism, time, true)
NetworkTestUtils.createSelector(channelBuilder, time) NetworkTestUtils.createSelector(channelBuilder, time)
} }
} }

View File

@ -26,6 +26,7 @@ import org.apache.kafka.common.config.types.Password
import org.apache.kafka.common.internals.FatalExitError import org.apache.kafka.common.internals.FatalExitError
import org.junit.{After, Before, Test} import org.junit.{After, Before, Test}
import org.junit.Assert._ import org.junit.Assert._
import org.apache.kafka.common.config.internals.BrokerSecurityConfigs
class KafkaTest { class KafkaTest {
@ -108,6 +109,21 @@ class KafkaTest {
assertEquals(password, config.getPassword(KafkaConfig.SslTruststorePasswordProp).value) assertEquals(password, config.getPassword(KafkaConfig.SslTruststorePasswordProp).value)
} }
@Test
def testConnectionsMaxReauthMsDefault(): Unit = {
val propertiesFile = prepareDefaultConfig()
val config = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile)))
assertEquals(0L, config.valuesWithPrefixOverride("sasl_ssl.oauthbearer.").get(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS).asInstanceOf[Long])
}
@Test
def testConnectionsMaxReauthMsExplicit(): Unit = {
val propertiesFile = prepareDefaultConfig()
val expected = 3600000
val config = KafkaConfig.fromProps(Kafka.getPropsFromArgs(Array(propertiesFile, "--override", s"sasl_ssl.oauthbearer.connections.max.reauth.ms=${expected}")))
assertEquals(expected, config.valuesWithPrefixOverride("sasl_ssl.oauthbearer.").get(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS).asInstanceOf[Long])
}
def prepareDefaultConfig(): String = { def prepareDefaultConfig(): String = {
prepareConfig(Array("broker.id=1", "zookeeper.connect=somewhere")) prepareConfig(Array("broker.id=1", "zookeeper.connect=somewhere"))
} }

View File

@ -675,6 +675,7 @@ class KafkaConfigTest {
case KafkaConfig.RackProp => // ignore string case KafkaConfig.RackProp => // ignore string
//SSL Configs //SSL Configs
case KafkaConfig.PrincipalBuilderClassProp => case KafkaConfig.PrincipalBuilderClassProp =>
case KafkaConfig.ConnectionsMaxReauthMsProp =>
case KafkaConfig.SslProtocolProp => // ignore string case KafkaConfig.SslProtocolProp => // ignore string
case KafkaConfig.SslProviderProp => // ignore string case KafkaConfig.SslProviderProp => // ignore string
case KafkaConfig.SslEnabledProtocolsProp => case KafkaConfig.SslEnabledProtocolsProp =>

View File

@ -968,6 +968,16 @@
<td>kafka.network:type=SocketServer,name=NetworkProcessorAvgIdlePercent</td> <td>kafka.network:type=SocketServer,name=NetworkProcessorAvgIdlePercent</td>
<td>between 0 and 1, ideally &gt 0.3</td> <td>between 0 and 1, ideally &gt 0.3</td>
</tr> </tr>
<tr>
<td>The number of connections disconnected on a processor due to a client not re-authenticating and then using the connection beyond its expiration time for anything other than re-authentication</td>
<td>kafka.server:type=socket-server-metrics,listener=[SASL_PLAINTEXT|SASL_SSL],networkProcessor=&lt;#&gt;,name=expired-connections-killed-count</td>
<td>ideally 0 when re-authentication is enabled, implying there are no longer any older, pre-2.2.0 clients connecting to this (listener, processor) combination</td>
</tr>
<tr>
<td>The total number of connections disconnected, across all processors, due to a client not re-authenticating and then using the connection beyond its expiration time for anything other than re-authentication</td>
<td>kafka.network:type=SocketServer,name=ExpiredConnectionsKilledCount</td>
<td>ideally 0 when re-authentication is enabled, implying there are no longer any older, pre-2.2.0 clients connecting to this broker</td>
</tr>
<tr> <tr>
<td>The average fraction of time the request handler threads are idle</td> <td>The average fraction of time the request handler threads are idle</td>
<td>kafka.server:type=KafkaRequestHandlerPool,name=RequestHandlerAvgIdlePercent</td> <td>kafka.server:type=KafkaRequestHandlerPool,name=RequestHandlerAvgIdlePercent</td>
@ -1152,6 +1162,41 @@
<td>Total connections that failed authentication.</td> <td>Total connections that failed authentication.</td>
<td>kafka.[producer|consumer|connect]:type=[producer|consumer|connect]-metrics,client-id=([-.\w]+)</td> <td>kafka.[producer|consumer|connect]:type=[producer|consumer|connect]-metrics,client-id=([-.\w]+)</td>
</tr> </tr>
<tr>
<td>successful-reauthentication-rate</td>
<td>Connections per second that were successfully re-authenticated using SASL.</td>
<td>kafka.[producer|consumer|connect]:type=[producer|consumer|connect]-metrics,client-id=([-.\w]+)</td>
</tr>
<tr>
<td>successful-reauthentication-total</td>
<td>Total connections that were successfully re-authenticated using SASL.</td>
<td>kafka.[producer|consumer|connect]:type=[producer|consumer|connect]-metrics,client-id=([-.\w]+)</td>
</tr>
<tr>
<td>reauthentication-latency-max</td>
<td>The maximum latency in ms observed due to re-authentication.</td>
<td>kafka.[producer|consumer|connect]:type=[producer|consumer|connect]-metrics,client-id=([-.\w]+)</td>
</tr>
<tr>
<td>reauthentication-latency-avg</td>
<td>The average latency in ms observed due to re-authentication.</td>
<td>kafka.[producer|consumer|connect]:type=[producer|consumer|connect]-metrics,client-id=([-.\w]+)</td>
</tr>
<tr>
<td>failed-reauthentication-rate</td>
<td>Connections per second that failed re-authentication.</td>
<td>kafka.[producer|consumer|connect]:type=[producer|consumer|connect]-metrics,client-id=([-.\w]+)</td>
</tr>
<tr>
<td>failed-reauthentication-total</td>
<td>Total connections that failed re-authentication.</td>
<td>kafka.[producer|consumer|connect]:type=[producer|consumer|connect]-metrics,client-id=([-.\w]+)</td>
</tr>
<tr>
<td>successful-authentication-no-reauth-total</td>
<td>Total connections that were successfully authenticated by older, pre-2.2.0 SASL clients that do not support re-authentication. May only be non-zero </td>
<td>kafka.[producer|consumer|connect]:type=[producer|consumer|connect]-metrics,client-id=([-.\w]+)</td>
</tr>
</tbody> </tbody>
</table> </table>

View File

@ -57,6 +57,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
public class ConnectionStressWorker implements TaskWorker { public class ConnectionStressWorker implements TaskWorker {
private static final Logger log = LoggerFactory.getLogger(ConnectionStressWorker.class); private static final Logger log = LoggerFactory.getLogger(ConnectionStressWorker.class);
private static final Time TIME = Time.SYSTEM;
private static final int THROTTLE_PERIOD_MS = 100; private static final int THROTTLE_PERIOD_MS = 100;
@ -100,7 +101,7 @@ public class ConnectionStressWorker implements TaskWorker {
this.status = status; this.status = status;
this.totalConnections = 0; this.totalConnections = 0;
this.totalFailedConnections = 0; this.totalFailedConnections = 0;
this.startTimeMs = Time.SYSTEM.milliseconds(); this.startTimeMs = TIME.milliseconds();
this.throttle = new ConnectStressThrottle(WorkerUtils. this.throttle = new ConnectStressThrottle(WorkerUtils.
perSecToPerPeriod(spec.targetConnectionsPerSec(), THROTTLE_PERIOD_MS)); perSecToPerPeriod(spec.targetConnectionsPerSec(), THROTTLE_PERIOD_MS));
this.nextReportTime = 0; this.nextReportTime = 0;
@ -168,11 +169,11 @@ public class ConnectionStressWorker implements TaskWorker {
try { try {
List<Node> nodes = updater.fetchNodes(); List<Node> nodes = updater.fetchNodes();
Node targetNode = nodes.get(ThreadLocalRandom.current().nextInt(nodes.size())); Node targetNode = nodes.get(ThreadLocalRandom.current().nextInt(nodes.size()));
try (ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(conf)) { try (ChannelBuilder channelBuilder = ClientUtils.createChannelBuilder(conf, TIME)) {
try (Metrics metrics = new Metrics()) { try (Metrics metrics = new Metrics()) {
LogContext logContext = new LogContext(); LogContext logContext = new LogContext();
try (Selector selector = new Selector(conf.getLong(AdminClientConfig.CONNECTIONS_MAX_IDLE_MS_CONFIG), try (Selector selector = new Selector(conf.getLong(AdminClientConfig.CONNECTIONS_MAX_IDLE_MS_CONFIG),
metrics, Time.SYSTEM, "", channelBuilder, logContext)) { metrics, TIME, "", channelBuilder, logContext)) {
try (NetworkClient client = new NetworkClient(selector, try (NetworkClient client = new NetworkClient(selector,
updater, updater,
"ConnectionStressWorker", "ConnectionStressWorker",
@ -183,11 +184,11 @@ public class ConnectionStressWorker implements TaskWorker {
4096, 4096,
1000, 1000,
ClientDnsLookup.forConfig(conf.getString(AdminClientConfig.CLIENT_DNS_LOOKUP_CONFIG)), ClientDnsLookup.forConfig(conf.getString(AdminClientConfig.CLIENT_DNS_LOOKUP_CONFIG)),
Time.SYSTEM, TIME,
false, false,
new ApiVersions(), new ApiVersions(),
logContext)) { logContext)) {
NetworkClientUtils.awaitReady(client, targetNode, Time.SYSTEM, 100); NetworkClientUtils.awaitReady(client, targetNode, TIME, 100);
} }
} }
} }