KAFKA-14604: SASL session expiration time will be overflowed when calculation (#18526)
CI / build (push) Has been cancelled Details
Fixup PR Labels / fixup-pr-labels (needs-attention) (push) Has been cancelled Details
Fixup PR Labels / fixup-pr-labels (triage) (push) Has been cancelled Details
Docker Image CVE Scanner / scan_jvm (3.7.2) (push) Has been cancelled Details
Docker Image CVE Scanner / scan_jvm (3.8.1) (push) Has been cancelled Details
Docker Image CVE Scanner / scan_jvm (3.9.1) (push) Has been cancelled Details
Docker Image CVE Scanner / scan_jvm (4.0.0) (push) Has been cancelled Details
Docker Image CVE Scanner / scan_jvm (latest) (push) Has been cancelled Details
Flaky Test Report / Flaky Test Report (push) Has been cancelled Details
Fixup PR Labels / needs-attention (push) Has been cancelled Details

The timeout value may be overflowed if users set a large expiration
time.

```
sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 1000 *
sessionLifetimeMs;
```

Fixed it by throwing exception if the value is overflowed.

Reviewers: TaiJuWu <tjwu1217@gmail.com>, Luke Chen <showuon@gmail.com>,
 TengYao Chi <frankvicky@apache.org>

Signed-off-by: PoAn Yang <payang@apache.org>
This commit is contained in:
PoAn Yang 2025-08-03 19:12:04 +08:00 committed by GitHub
parent 3f1d830174
commit ea771563e0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 136 additions and 14 deletions

View File

@ -690,7 +690,7 @@ public class SaslClientAuthenticator implements Authenticator {
double pctToUse = pctWindowFactorToTakeNetworkLatencyAndClockDriftIntoAccount + RNG.nextDouble()
* pctWindowJitterToAvoidReauthenticationStormAcrossManyChannelsSimultaneously;
sessionLifetimeMsToUse = (long) (positiveSessionLifetimeMs * pctToUse);
clientSessionReauthenticationTimeNanos = authenticationEndNanos + 1000 * 1000 * sessionLifetimeMsToUse;
clientSessionReauthenticationTimeNanos = Math.addExact(authenticationEndNanos, Utils.msToNs(sessionLifetimeMsToUse));
log.debug(
"Finished {} with session expiration in {} ms and session re-authentication on or after {} ms",
authenticationOrReauthenticationText(), positiveSessionLifetimeMs, sessionLifetimeMsToUse);

View File

@ -681,7 +681,7 @@ public class SaslServerAuthenticator implements Authenticator {
else
retvalSessionLifetimeMs = zeroIfNegative(Math.min(credentialExpirationMs - authenticationEndMs, connectionsMaxReauthMs));
sessionExpirationTimeNanos = authenticationEndNanos + 1000 * 1000 * retvalSessionLifetimeMs;
sessionExpirationTimeNanos = Math.addExact(authenticationEndNanos, Utils.msToNs(retvalSessionLifetimeMs));
}
if (credentialExpirationMs != null) {

View File

@ -1719,4 +1719,17 @@ public final class Utils {
public interface ThrowingRunnable {
void run() throws Exception;
}
/**
* convert millisecond to nanosecond, or throw exception if overflow
* @param timeMs the time in millisecond
* @return the converted nanosecond
*/
public static long msToNs(long timeMs) {
try {
return Math.multiplyExact(1000 * 1000, timeMs);
} catch (ArithmeticException e) {
throw new IllegalArgumentException("Cannot convert " + timeMs + " millisecond to nanosecond due to arithmetic overflow", e);
}
}
}

View File

@ -158,6 +158,7 @@ public class SaslAuthenticatorTest {
private static final long CONNECTIONS_MAX_REAUTH_MS_VALUE = 100L;
private static final int BUFFER_SIZE = 4 * 1024;
private static Time time = Time.SYSTEM;
private static boolean needLargeExpiration = false;
private NioEchoServer server;
private Selector selector;
@ -181,6 +182,7 @@ public class SaslAuthenticatorTest {
@AfterEach
public void teardown() throws Exception {
needLargeExpiration = false;
if (server != null)
this.server.close();
if (selector != null)
@ -1610,6 +1612,42 @@ public class SaslAuthenticatorTest {
server.verifyReauthenticationMetrics(0, 1);
}
@Test
public void testReauthenticateWithLargeReauthValue() throws Exception {
// enable it, we'll get a large expiration timestamp token
needLargeExpiration = true;
String node = "0";
SecurityProtocol securityProtocol = SecurityProtocol.SASL_SSL;
configureMechanisms(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
List.of(OAuthBearerLoginModule.OAUTHBEARER_MECHANISM));
// set a large re-auth timeout in server side
saslServerConfigs.put(BrokerSecurityConfigs.CONNECTIONS_MAX_REAUTH_MS_CONFIG, Long.MAX_VALUE);
server = createEchoServer(securityProtocol);
// set to default value for sasl login configs for initialization in ExpiringCredentialRefreshConfig
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_FACTOR, SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_FACTOR);
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_WINDOW_JITTER, SaslConfigs.DEFAULT_LOGIN_REFRESH_WINDOW_JITTER);
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_MIN_PERIOD_SECONDS, SaslConfigs.DEFAULT_LOGIN_REFRESH_MIN_PERIOD_SECONDS);
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_REFRESH_BUFFER_SECONDS, SaslConfigs.DEFAULT_LOGIN_REFRESH_BUFFER_SECONDS);
saslClientConfigs.put(SaslConfigs.SASL_LOGIN_CALLBACK_HANDLER_CLASS, AlternateLoginCallbackHandler.class);
createCustomClientConnection(securityProtocol, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, node, true);
// channel should be not null before sasl handshake
assertNotNull(selector.channel(node));
TestUtils.waitForCondition(() -> {
selector.poll(1000);
// this channel should be closed due to session timeout calculation overflow
return selector.channel(node) == null;
}, "channel didn't close with large re-authentication value");
// ensure metrics are as expected
server.verifyAuthenticationMetrics(0, 0);
server.verifyReauthenticationMetrics(0, 0);
}
@Test
public void testCorrelationId() {
SaslClientAuthenticator authenticator = new SaslClientAuthenticator(
@ -2002,7 +2040,7 @@ public class SaslAuthenticatorTest {
if (enableSaslAuthenticateHeader)
createClientConnection(securityProtocol, node);
else
createClientConnectionWithoutSaslAuthenticateHeader(securityProtocol, saslMechanism, node);
createCustomClientConnection(securityProtocol, saslMechanism, node, false);
}
private NioEchoServer startServerApiVersionsUnsupportedByClient(final SecurityProtocol securityProtocol, String saslMechanism) throws Exception {
@ -2090,15 +2128,13 @@ public class SaslAuthenticatorTest {
return server;
}
private void createClientConnectionWithoutSaslAuthenticateHeader(final SecurityProtocol securityProtocol,
final String saslMechanism, String node) throws Exception {
final ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol);
final Map<String, ?> configs = Collections.emptyMap();
final JaasContext jaasContext = JaasContext.loadClientContext(configs);
final Map<String, JaasContext> jaasContexts = Collections.singletonMap(saslMechanism, jaasContext);
SaslChannelBuilder clientChannelBuilder = new SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts,
private SaslChannelBuilder saslChannelBuilderWithoutHeader(
final SecurityProtocol securityProtocol,
final String saslMechanism,
final Map<String, JaasContext> jaasContexts,
final ListenerName listenerName
) {
return new SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts,
securityProtocol, listenerName, false, saslMechanism,
null, null, null, time, new LogContext(), null) {
@ -2125,6 +2161,42 @@ public class SaslAuthenticatorTest {
};
}
};
}
private void createCustomClientConnection(
final SecurityProtocol securityProtocol,
final String saslMechanism,
String node,
boolean withSaslAuthenticateHeader
) throws Exception {
final ListenerName listenerName = ListenerName.forSecurityProtocol(securityProtocol);
final Map<String, ?> configs = Collections.emptyMap();
final JaasContext jaasContext = JaasContext.loadClientContext(configs);
final Map<String, JaasContext> jaasContexts = Collections.singletonMap(saslMechanism, jaasContext);
SaslChannelBuilder clientChannelBuilder;
if (!withSaslAuthenticateHeader) {
clientChannelBuilder = saslChannelBuilderWithoutHeader(securityProtocol, saslMechanism, jaasContexts, listenerName);
} else {
clientChannelBuilder = new SaslChannelBuilder(ConnectionMode.CLIENT, jaasContexts,
securityProtocol, listenerName, false, saslMechanism,
null, null, null, time, new LogContext(), null) {
@Override
protected SaslClientAuthenticator buildClientAuthenticator(Map<String, ?> configs,
AuthenticateCallbackHandler callbackHandler,
String id,
String serverHost,
String servicePrincipal,
TransportLayer transportLayer,
Subject subject) {
return new SaslClientAuthenticator(configs, callbackHandler, id, subject,
servicePrincipal, serverHost, saslMechanism, transportLayer, time, new LogContext());
}
};
}
clientChannelBuilder.configure(saslClientConfigs);
this.selector = NetworkTestUtils.createSelector(clientChannelBuilder, time);
InetSocketAddress addr = new InetSocketAddress("localhost", server.port());
@ -2581,10 +2653,11 @@ public class SaslAuthenticatorTest {
+ ++numInvocations;
String headerJson = "{" + claimOrHeaderJsonText("alg", "none") + "}";
/*
* Use a short lifetime so the background refresh thread replaces it before we
* If we're testing large expiration scenario, use a large lifetime.
* Otherwise, use a short lifetime so the background refresh thread replaces it before we
* re-authenticate
*/
String lifetimeSecondsValueToUse = "1";
String lifetimeSecondsValueToUse = needLargeExpiration ? String.valueOf(Long.MAX_VALUE) : "1";
String claimsJson;
try {
claimsJson = String.format("{%s,%s,%s}",

View File

@ -270,6 +270,35 @@ public class SaslServerAuthenticatorTest {
}
}
@Test
public void testSessionWontExpireWithLargeExpirationTime() throws IOException {
String mechanism = OAuthBearerLoginModule.OAUTHBEARER_MECHANISM;
SaslServer saslServer = mock(SaslServer.class);
MockTime time = new MockTime(0, 1, 1000);
// set a Long.MAX_VALUE as the expiration time
Duration largeExpirationTime = Duration.ofMillis(Long.MAX_VALUE);
try (
MockedStatic<?> ignored = mockSaslServer(saslServer, mechanism, time, largeExpirationTime);
MockedStatic<?> ignored2 = mockKafkaPrincipal("[principal-type]", "[principal-name");
TransportLayer transportLayer = mockTransportLayer()
) {
SaslServerAuthenticator authenticator = getSaslServerAuthenticatorForOAuth(mechanism, transportLayer, time, largeExpirationTime.toMillis());
mockRequest(saslHandshakeRequest(mechanism), transportLayer);
authenticator.authenticate();
when(saslServer.isComplete()).thenReturn(false).thenReturn(true);
mockRequest(saslAuthenticateRequest(), transportLayer);
Throwable t = assertThrows(IllegalArgumentException.class, () -> authenticator.authenticate());
assertEquals(ArithmeticException.class, t.getCause().getClass());
assertEquals("Cannot convert " + Long.MAX_VALUE + " millisecond to nanosecond due to arithmetic overflow",
t.getMessage());
}
}
private SaslServerAuthenticator getSaslServerAuthenticatorForOAuth(String mechanism, TransportLayer transportLayer, Time time, Long maxReauth) {
Map<String, ?> configs = Collections.singletonMap(BrokerSecurityConfigs.SASL_ENABLED_MECHANISMS_CONFIG,
Collections.singletonList(mechanism));

View File

@ -1269,6 +1269,13 @@ public class UtilsTest {
assertEquals(expected, recorded);
}
@Test
public void testMsToNs() {
assertEquals(1000000, Utils.msToNs(1));
assertEquals(0, Utils.msToNs(0));
assertThrows(IllegalArgumentException.class, () -> Utils.msToNs(Long.MAX_VALUE));
}
private Callable<Void> recordingCallable(Map<String, Object> recordingMap, String success, TestException failure) {
return () -> {
if (success == null)