From b4e1deb43a75ca84262d877c5f47bbf2b0dbc6c4 Mon Sep 17 00:00:00 2001 From: Vikas Singh Date: Mon, 9 Sep 2024 18:43:33 +0530 Subject: [PATCH] MINOR: Few cleanups Reviewers: Manikumar Reddy --- .../scram/internals/ScramSaslServer.java | 6 +- .../scram/internals/ScramSaslServerTest.java | 68 +++++++++++++++++++ 2 files changed, 73 insertions(+), 1 deletion(-) diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java index cea3ddf71fd..e3a300f9a7b 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java +++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java @@ -149,6 +149,9 @@ public class ScramSaslServer implements SaslServer { case RECEIVE_CLIENT_FINAL_MESSAGE: try { ClientFinalMessage clientFinalMessage = new ClientFinalMessage(response); + if (!clientFinalMessage.nonce().endsWith(serverFirstMessage.nonce())) { + throw new SaslException("Invalid client nonce in the final client message."); + } verifyClientProof(clientFinalMessage); byte[] serverKey = scramCredential.serverKey(); byte[] serverSignature = formatter.serverSignature(serverKey, clientFirstMessage, serverFirstMessage, clientFinalMessage); @@ -222,7 +225,8 @@ public class ScramSaslServer implements SaslServer { this.state = state; } - private void verifyClientProof(ClientFinalMessage clientFinalMessage) throws SaslException { + // Visible for testing + void verifyClientProof(ClientFinalMessage clientFinalMessage) throws SaslException { try { byte[] expectedStoredKey = scramCredential.storedKey(); byte[] clientSignature = formatter.clientSignature(expectedStoredKey, clientFirstMessage, serverFirstMessage, clientFinalMessage); diff --git a/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerTest.java b/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerTest.java index 1393b26f87a..94b95b0cfdf 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/scram/internals/ScramSaslServerTest.java @@ -20,14 +20,23 @@ package org.apache.kafka.common.security.scram.internals; import org.apache.kafka.common.errors.SaslAuthenticationException; import org.apache.kafka.common.security.authenticator.CredentialCache; import org.apache.kafka.common.security.scram.ScramCredential; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ClientFirstMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFinalMessage; +import org.apache.kafka.common.security.scram.internals.ScramMessages.ServerFirstMessage; import org.apache.kafka.common.security.token.delegation.internals.DelegationTokenCache; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.Mockito; import java.nio.charset.StandardCharsets; +import java.util.Base64; import java.util.HashMap; +import javax.security.sasl.SaslException; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -67,10 +76,69 @@ public class ScramSaslServerTest { assertThrows(SaslAuthenticationException.class, () -> saslServer.evaluateResponse(clientFirstMessage(USER_A, USER_B))); } + /** + * Validate that server responds with client's nonce as prefix of its nonce in the + * server first message. + *
+ * In addition, it checks that the client final message has nonce that it sent in its + * first message. + */ + @Test + public void validateNonceExchange() throws SaslException { + ScramSaslServer spySaslServer = Mockito.spy(saslServer); + byte[] clientFirstMsgBytes = clientFirstMessage(USER_A, USER_A); + ClientFirstMessage clientFirstMessage = new ClientFirstMessage(clientFirstMsgBytes); + + byte[] serverFirstMsgBytes = spySaslServer.evaluateResponse(clientFirstMsgBytes); + ServerFirstMessage serverFirstMessage = new ServerFirstMessage(serverFirstMsgBytes); + assertTrue(serverFirstMessage.nonce().startsWith(clientFirstMessage.nonce()), + "Nonce in server message should start with client first message's nonce"); + + byte[] clientFinalMessage = clientFinalMessage(serverFirstMessage.nonce()); + Mockito.doNothing() + .when(spySaslServer).verifyClientProof(Mockito.any(ScramMessages.ClientFinalMessage.class)); + byte[] serverFinalMsgBytes = spySaslServer.evaluateResponse(clientFinalMessage); + ServerFinalMessage serverFinalMessage = new ServerFinalMessage(serverFinalMsgBytes); + assertNull(serverFinalMessage.error(), "Server final message should not contain error"); + } + + @Test + public void validateFailedNonceExchange() throws SaslException { + ScramSaslServer spySaslServer = Mockito.spy(saslServer); + byte[] clientFirstMsgBytes = clientFirstMessage(USER_A, USER_A); + ClientFirstMessage clientFirstMessage = new ClientFirstMessage(clientFirstMsgBytes); + + byte[] serverFirstMsgBytes = spySaslServer.evaluateResponse(clientFirstMsgBytes); + ServerFirstMessage serverFirstMessage = new ServerFirstMessage(serverFirstMsgBytes); + assertTrue(serverFirstMessage.nonce().startsWith(clientFirstMessage.nonce()), + "Nonce in server message should start with client first message's nonce"); + + byte[] clientFinalMessage = clientFinalMessage(formatter.secureRandomString()); + Mockito.doNothing() + .when(spySaslServer).verifyClientProof(Mockito.any(ScramMessages.ClientFinalMessage.class)); + SaslException saslException = assertThrows(SaslException.class, + () -> spySaslServer.evaluateResponse(clientFinalMessage)); + assertEquals("Invalid client nonce in the final client message.", + saslException.getMessage(), + "Failure message: " + saslException.getMessage()); + } + private byte[] clientFirstMessage(String userName, String authorizationId) { String nonce = formatter.secureRandomString(); String authorizationField = authorizationId != null ? "a=" + authorizationId : ""; String firstMessage = String.format("n,%s,n=%s,r=%s", authorizationField, userName, nonce); return firstMessage.getBytes(StandardCharsets.UTF_8); } + + private byte[] clientFinalMessage(String nonce) { + String channelBinding = randomBytesAsString(); + String proof = randomBytesAsString(); + + String message = String.format("c=%s,r=%s,p=%s", channelBinding, nonce, proof); + return message.getBytes(StandardCharsets.UTF_8); + } + + private String randomBytesAsString() { + return Base64.getEncoder().encodeToString(formatter.secureRandomBytes()); + } }