diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginCallbackHandler.java index fc9e6896115..0d8701ba11d 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginCallbackHandler.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginCallbackHandler.java @@ -24,12 +24,13 @@ import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; import org.apache.kafka.common.security.auth.SaslExtensions; import org.apache.kafka.common.security.auth.SaslExtensionsCallback; import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetrieverFactory; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory; +import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtRetriever; +import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator; import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils; +import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtRetriever; +import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator; import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException; +import org.apache.kafka.common.utils.Utils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -179,55 +180,48 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand private Map moduleOptions; - private AccessTokenRetriever accessTokenRetriever; + private JwtRetriever jwtRetriever; - private AccessTokenValidator accessTokenValidator; - - private boolean isInitialized = false; + private JwtValidator jwtValidator; @Override public void configure(Map configs, String saslMechanism, List jaasConfigEntries) { - moduleOptions = JaasOptionsUtils.getOptions(saslMechanism, jaasConfigEntries); - AccessTokenRetriever accessTokenRetriever = AccessTokenRetrieverFactory.create(configs, saslMechanism, moduleOptions); - AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs, saslMechanism); - init(accessTokenRetriever, accessTokenValidator); - } - - public void init(AccessTokenRetriever accessTokenRetriever, AccessTokenValidator accessTokenValidator) { - this.accessTokenRetriever = accessTokenRetriever; - this.accessTokenValidator = accessTokenValidator; - - try { - this.accessTokenRetriever.init(); - } catch (IOException e) { - throw new KafkaException("The OAuth login configuration encountered an error when initializing the AccessTokenRetriever", e); - } - - isInitialized = true; + Map moduleOptions = JaasOptionsUtils.getOptions(saslMechanism, jaasConfigEntries); + JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, saslMechanism, moduleOptions); + JwtValidator jwtValidator = new DefaultJwtValidator(configs, saslMechanism); + init(moduleOptions, jwtRetriever, jwtValidator); } /* * Package-visible for testing. */ + void init(Map moduleOptions, JwtRetriever jwtRetriever, JwtValidator jwtValidator) { + this.moduleOptions = moduleOptions; + this.jwtRetriever = jwtRetriever; + this.jwtValidator = jwtValidator; - AccessTokenRetriever getAccessTokenRetriever() { - return accessTokenRetriever; - } + try { + this.jwtRetriever.init(); + } catch (IOException e) { + throw new KafkaException("The OAuth login callback encountered an error when initializing the JwtRetriever", e); + } - @Override - public void close() { - if (accessTokenRetriever != null) { - try { - this.accessTokenRetriever.close(); - } catch (IOException e) { - log.warn("The OAuth login configuration encountered an error when closing the AccessTokenRetriever", e); - } + try { + this.jwtValidator.init(); + } catch (IOException e) { + throw new KafkaException("The OAuth login callback encountered an error when initializing the JwtValidator", e); } } + @Override + public void close() { + Utils.closeQuietly(jwtRetriever, "JWT retriever"); + Utils.closeQuietly(jwtValidator, "JWT validator"); + } + @Override public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { - checkInitialized(); + checkConfigured(); for (Callback callback : callbacks) { if (callback instanceof OAuthBearerTokenCallback) { @@ -241,11 +235,11 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand } private void handleTokenCallback(OAuthBearerTokenCallback callback) throws IOException { - checkInitialized(); - String accessToken = accessTokenRetriever.retrieve(); + checkConfigured(); + String accessToken = jwtRetriever.retrieve(); try { - OAuthBearerToken token = accessTokenValidator.validate(accessToken); + OAuthBearerToken token = jwtValidator.validate(accessToken); callback.token(token); } catch (ValidateException e) { log.warn(e.getMessage(), e); @@ -254,7 +248,7 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand } private void handleExtensionsCallback(SaslExtensionsCallback callback) { - checkInitialized(); + checkConfigured(); Map extensions = new HashMap<>(); @@ -286,9 +280,9 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand callback.extensions(saslExtensions); } - private void checkInitialized() { - if (!isInitialized) - throw new IllegalStateException(String.format("To use %s, first call the configure or init method", getClass().getSimpleName())); + private void checkConfigured() { + if (moduleOptions == null || jwtRetriever == null || jwtValidator == null) + throw new IllegalStateException(String.format("To use %s, first call the configure method", getClass().getSimpleName())); } } diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackHandler.java index f9422370db1..c10b7db4e24 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackHandler.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackHandler.java @@ -19,13 +19,14 @@ package org.apache.kafka.common.security.oauthbearer; import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory; import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver; +import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator; import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils; +import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator; import org.apache.kafka.common.security.oauthbearer.internals.secured.RefreshingHttpsJwksVerificationKeyResolver; import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException; import org.apache.kafka.common.security.oauthbearer.internals.secured.VerificationKeyResolverFactory; +import org.apache.kafka.common.utils.Utils; import org.jose4j.jws.JsonWebSignature; import org.jose4j.jwx.JsonWebStructure; @@ -119,9 +120,7 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback private CloseableVerificationKeyResolver verificationKeyResolver; - private AccessTokenValidator accessTokenValidator; - - private boolean isInitialized = false; + private JwtValidator jwtValidator; @Override public void configure(Map configs, String saslMechanism, List jaasConfigEntries) { @@ -135,37 +134,39 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback new RefCountingVerificationKeyResolver(VerificationKeyResolverFactory.create(configs, saslMechanism, moduleOptions))); } - AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs, saslMechanism, verificationKeyResolver); - init(verificationKeyResolver, accessTokenValidator); + JwtValidator jwtValidator = new DefaultJwtValidator(configs, saslMechanism, verificationKeyResolver); + init(verificationKeyResolver, jwtValidator); } - public void init(CloseableVerificationKeyResolver verificationKeyResolver, AccessTokenValidator accessTokenValidator) { + /* + * Package-visible for testing. + */ + void init(CloseableVerificationKeyResolver verificationKeyResolver, JwtValidator jwtValidator) { this.verificationKeyResolver = verificationKeyResolver; - this.accessTokenValidator = accessTokenValidator; + this.jwtValidator = jwtValidator; try { verificationKeyResolver.init(); } catch (Exception e) { - throw new KafkaException("The OAuth validator configuration encountered an error when initializing the VerificationKeyResolver", e); + throw new KafkaException("The OAuth validator callback encountered an error when initializing the VerificationKeyResolver", e); } - isInitialized = true; + try { + jwtValidator.init(); + } catch (IOException e) { + throw new KafkaException("The OAuth validator callback encountered an error when initializing the JwtValidator", e); + } } @Override public void close() { - if (verificationKeyResolver != null) { - try { - verificationKeyResolver.close(); - } catch (Exception e) { - log.error(e.getMessage(), e); - } - } + Utils.closeQuietly(jwtValidator, "JWT validator"); + Utils.closeQuietly(verificationKeyResolver, "JWT verification key resolver"); } @Override public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { - checkInitialized(); + checkConfigured(); for (Callback callback : callbacks) { if (callback instanceof OAuthBearerValidatorCallback) { @@ -179,12 +180,12 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback } private void handleValidatorCallback(OAuthBearerValidatorCallback callback) { - checkInitialized(); + checkConfigured(); OAuthBearerToken token; try { - token = accessTokenValidator.validate(callback.tokenValue()); + token = jwtValidator.validate(callback.tokenValue()); callback.token(token); } catch (ValidateException e) { log.warn(e.getMessage(), e); @@ -193,14 +194,14 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback } private void handleExtensionsValidatorCallback(OAuthBearerExtensionsValidatorCallback extensionsValidatorCallback) { - checkInitialized(); + checkConfigured(); extensionsValidatorCallback.inputExtensions().map().forEach((extensionName, v) -> extensionsValidatorCallback.valid(extensionName)); } - private void checkInitialized() { - if (!isInitialized) - throw new IllegalStateException(String.format("To use %s, first call the configure or init method", getClass().getSimpleName())); + private void checkConfigured() { + if (verificationKeyResolver == null || jwtValidator == null) + throw new IllegalStateException(String.format("To use %s, first call the configure method", getClass().getSimpleName())); } /** diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenValidatorFactory.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenValidatorFactory.java deleted file mode 100644 index e4b39e5cc53..00000000000 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenValidatorFactory.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.kafka.common.security.oauthbearer.internals.secured; - -import org.jose4j.keys.resolvers.VerificationKeyResolver; - -import java.util.List; -import java.util.Map; -import java.util.Set; - -import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS; -import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE; -import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_ISSUER; -import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SCOPE_CLAIM_NAME; -import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SUB_CLAIM_NAME; - -public class AccessTokenValidatorFactory { - - public static AccessTokenValidator create(Map configs) { - return create(configs, (String) null); - } - - public static AccessTokenValidator create(Map configs, String saslMechanism) { - ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism); - String scopeClaimName = cu.get(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME); - String subClaimName = cu.get(SASL_OAUTHBEARER_SUB_CLAIM_NAME); - return new LoginAccessTokenValidator(scopeClaimName, subClaimName); - } - - public static AccessTokenValidator create(Map configs, - VerificationKeyResolver verificationKeyResolver) { - return create(configs, null, verificationKeyResolver); - } - - public static AccessTokenValidator create(Map configs, - String saslMechanism, - VerificationKeyResolver verificationKeyResolver) { - ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism); - Set expectedAudiences = null; - List l = cu.get(SASL_OAUTHBEARER_EXPECTED_AUDIENCE); - - if (l != null) - expectedAudiences = Set.copyOf(l); - - Integer clockSkew = cu.validateInteger(SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS, false); - String expectedIssuer = cu.validateString(SASL_OAUTHBEARER_EXPECTED_ISSUER, false); - String scopeClaimName = cu.validateString(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME); - String subClaimName = cu.validateString(SASL_OAUTHBEARER_SUB_CLAIM_NAME); - - return new ValidatorAccessTokenValidator(clockSkew, - expectedAudiences, - expectedIssuer, - verificationKeyResolver, - scopeClaimName, - subClaimName); - } - -} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ValidatorAccessTokenValidator.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/BrokerJwtValidator.java similarity index 93% rename from clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ValidatorAccessTokenValidator.java rename to clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/BrokerJwtValidator.java index c7ae8edae9d..74ad4765222 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ValidatorAccessTokenValidator.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/BrokerJwtValidator.java @@ -38,7 +38,7 @@ import java.util.Set; import static org.jose4j.jwa.AlgorithmConstraints.DISALLOW_NONE; /** - * ValidatorAccessTokenValidator is an implementation of {@link AccessTokenValidator} that is used + * {@code BrokerJwtValidator} is an implementation of {@link JwtValidator} that is used * by the broker to perform more extensive validation of the JWT access token that is received * from the client, but ultimately from posting the client credentials to the OAuth/OIDC provider's * token endpoint. @@ -62,9 +62,9 @@ import static org.jose4j.jwa.AlgorithmConstraints.DISALLOW_NONE; * */ -public class ValidatorAccessTokenValidator implements AccessTokenValidator { +public class BrokerJwtValidator implements JwtValidator { - private static final Logger log = LoggerFactory.getLogger(ValidatorAccessTokenValidator.class); + private static final Logger log = LoggerFactory.getLogger(BrokerJwtValidator.class); private final JwtConsumer jwtConsumer; @@ -73,7 +73,7 @@ public class ValidatorAccessTokenValidator implements AccessTokenValidator { private final String subClaimName; /** - * Creates a new ValidatorAccessTokenValidator that will be used by the broker for more + * Creates a new {@code BrokerJwtValidator} that will be used by the broker for more * thorough validation of the JWT. * * @param clockSkew The optional value (in seconds) to allow for differences @@ -112,12 +112,12 @@ public class ValidatorAccessTokenValidator implements AccessTokenValidator { * @see VerificationKeyResolver */ - public ValidatorAccessTokenValidator(Integer clockSkew, - Set expectedAudiences, - String expectedIssuer, - VerificationKeyResolver verificationKeyResolver, - String scopeClaimName, - String subClaimName) { + public BrokerJwtValidator(Integer clockSkew, + Set expectedAudiences, + String expectedIssuer, + VerificationKeyResolver verificationKeyResolver, + String scopeClaimName, + String subClaimName) { final JwtConsumerBuilder jwtConsumerBuilder = new JwtConsumerBuilder(); if (clockSkew != null) diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/LoginAccessTokenValidator.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ClientJwtValidator.java similarity index 90% rename from clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/LoginAccessTokenValidator.java rename to clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ClientJwtValidator.java index 773311ff0ab..1dee4671d39 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/LoginAccessTokenValidator.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ClientJwtValidator.java @@ -33,7 +33,7 @@ import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARE import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME; /** - * LoginAccessTokenValidator is an implementation of {@link AccessTokenValidator} that is used + * {@code ClientJwtValidator} is an implementation of {@link JwtValidator} that is used * by the client to perform some rudimentary validation of the JWT access token that is received * as part of the response from posting the client credentials to the OAuth/OIDC provider's * token endpoint. @@ -46,13 +46,13 @@ import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARE * RFC 6750 Section 2.1 * *
  • Basic conversion of the token into an in-memory map
  • - *
  • Presence of scope, exp, subject, and iat claims
  • + *
  • Presence of scope, exp, subject, and iat claims
  • * */ -public class LoginAccessTokenValidator implements AccessTokenValidator { +public class ClientJwtValidator implements JwtValidator { - private static final Logger log = LoggerFactory.getLogger(LoginAccessTokenValidator.class); + private static final Logger log = LoggerFactory.getLogger(ClientJwtValidator.class); public static final String EXPIRATION_CLAIM_NAME = "exp"; @@ -63,14 +63,14 @@ public class LoginAccessTokenValidator implements AccessTokenValidator { private final String subClaimName; /** - * Creates a new LoginAccessTokenValidator that will be used by the client for lightweight + * Creates a new {@code ClientJwtValidator} that will be used by the client for lightweight * validation of the JWT. * * @param scopeClaimName Name of the scope claim to use; must be non-null * @param subClaimName Name of the subject claim to use; must be non-null */ - public LoginAccessTokenValidator(String scopeClaimName, String subClaimName) { + public ClientJwtValidator(String scopeClaimName, String subClaimName) { this.scopeClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SCOPE_CLAIM_NAME, scopeClaimName); this.subClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME, subClaimName); } diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenRetrieverFactory.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/DefaultJwtRetriever.java similarity index 68% rename from clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenRetrieverFactory.java rename to clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/DefaultJwtRetriever.java index 0ed4a1a2303..2d607ddcda8 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenRetrieverFactory.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/DefaultJwtRetriever.java @@ -18,7 +18,9 @@ package org.apache.kafka.common.security.oauthbearer.internals.secured; import org.apache.kafka.common.config.SaslConfigs; +import org.apache.kafka.common.utils.Utils; +import java.io.IOException; import java.net.URL; import java.util.Locale; import java.util.Map; @@ -36,32 +38,33 @@ import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallb import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG; import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.SCOPE_CONFIG; -public class AccessTokenRetrieverFactory { +/** + * {@code DefaultJwtRetriever} instantiates and delegates {@link JwtRetriever} API calls to an embedded implementation + * based on configuration. If {@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL} is configured with a + * {@code file}-based URL, a {@link FileJwtRetriever} is created and the JWT is expected be contained in the file + * specified. Otherwise, it's assumed to be an HTTP/HTTPS-based URL, so an {@link HttpJwtRetriever} is created. + */ +public class DefaultJwtRetriever implements JwtRetriever { - /** - * Create an {@link AccessTokenRetriever} from the given SASL and JAAS configuration. - * - * Note: the returned AccessTokenRetriever is not initialized - * here and must be done by the caller prior to use. - * - * @param configs SASL configuration - * @param jaasConfig JAAS configuration - * - * @return Non-null {@link AccessTokenRetriever} - */ + private final Map configs; + private final String saslMechanism; + private final Map jaasConfig; - public static AccessTokenRetriever create(Map configs, Map jaasConfig) { - return create(configs, null, jaasConfig); + private JwtRetriever delegate; + + public DefaultJwtRetriever(Map configs, String saslMechanism, Map jaasConfig) { + this.configs = configs; + this.saslMechanism = saslMechanism; + this.jaasConfig = jaasConfig; } - public static AccessTokenRetriever create(Map configs, - String saslMechanism, - Map jaasConfig) { + @Override + public void init() throws IOException { ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism); URL tokenEndpointUrl = cu.validateUrl(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL); if (tokenEndpointUrl.getProtocol().toLowerCase(Locale.ROOT).equals("file")) { - return new FileTokenRetriever(cu.validateFile(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL)); + delegate = new FileJwtRetriever(cu.validateFile(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL)); } else { JaasOptionsUtils jou = new JaasOptionsUtils(jaasConfig); String clientId = jou.validateString(CLIENT_ID_CONFIG); @@ -75,7 +78,7 @@ public class AccessTokenRetrieverFactory { boolean urlencodeHeader = validateUrlencodeHeader(cu); - return new HttpAccessTokenRetriever(clientId, + delegate = new HttpJwtRetriever(clientId, clientSecret, scope, sslSocketFactory, @@ -86,6 +89,21 @@ public class AccessTokenRetrieverFactory { cu.validateInteger(SASL_LOGIN_READ_TIMEOUT_MS, false), urlencodeHeader); } + + delegate.init(); + } + + @Override + public String retrieve() throws IOException { + if (delegate == null) + throw new IllegalStateException("JWT retriever delegate is null; please call init() first"); + + return delegate.retrieve(); + } + + @Override + public void close() throws IOException { + Utils.closeQuietly(delegate, "JWT retriever delegate"); } /** @@ -96,10 +114,10 @@ public class AccessTokenRetrieverFactory { *

    * * This utility method ensures that we have a non-{@code null} value to use in the - * {@link HttpAccessTokenRetriever} constructor. + * {@link HttpJwtRetriever} constructor. */ static boolean validateUrlencodeHeader(ConfigurationUtils configurationUtils) { - Boolean urlencodeHeader = configurationUtils.validateBoolean(SASL_OAUTHBEARER_HEADER_URLENCODE, false); + Boolean urlencodeHeader = configurationUtils.get(SASL_OAUTHBEARER_HEADER_URLENCODE); if (urlencodeHeader != null) return urlencodeHeader; @@ -107,4 +125,7 @@ public class AccessTokenRetrieverFactory { return DEFAULT_SASL_OAUTHBEARER_HEADER_URLENCODE; } + JwtRetriever delegate() { + return delegate; + } } \ No newline at end of file diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/DefaultJwtValidator.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/DefaultJwtValidator.java new file mode 100644 index 00000000000..5cd1e61db88 --- /dev/null +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/DefaultJwtValidator.java @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.internals.secured; + +import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import org.apache.kafka.common.utils.Utils; + +import org.jose4j.keys.resolvers.VerificationKeyResolver; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS; +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE; +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_ISSUER; +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SCOPE_CLAIM_NAME; +import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SUB_CLAIM_NAME; + +/** + * This {@link JwtValidator} uses the delegation approach, instantiating and delegating calls to a + * more concrete implementation. The underlying implementation is determined by the presence/absence + * of the {@link VerificationKeyResolver}: if it's present, a {@link BrokerJwtValidator} is + * created, otherwise a {@link ClientJwtValidator} is created. + */ +public class DefaultJwtValidator implements JwtValidator { + + private final Map configs; + private final String saslMechanism; + private final Optional verificationKeyResolver; + + private JwtValidator delegate; + + public DefaultJwtValidator(Map configs, String saslMechanism) { + this.configs = configs; + this.saslMechanism = saslMechanism; + this.verificationKeyResolver = Optional.empty(); + } + + public DefaultJwtValidator(Map configs, + String saslMechanism, + VerificationKeyResolver verificationKeyResolver) { + this.configs = configs; + this.saslMechanism = saslMechanism; + this.verificationKeyResolver = Optional.of(verificationKeyResolver); + } + + @Override + public void init() throws IOException { + ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism); + + if (verificationKeyResolver.isPresent()) { + List expectedAudiencesList = cu.get(SASL_OAUTHBEARER_EXPECTED_AUDIENCE); + Set expectedAudiences = expectedAudiencesList != null ? Set.copyOf(expectedAudiencesList) : null; + Integer clockSkew = cu.validateInteger(SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS, false); + String expectedIssuer = cu.validateString(SASL_OAUTHBEARER_EXPECTED_ISSUER, false); + String scopeClaimName = cu.validateString(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME); + String subClaimName = cu.validateString(SASL_OAUTHBEARER_SUB_CLAIM_NAME); + + delegate = new BrokerJwtValidator(clockSkew, + expectedAudiences, + expectedIssuer, + verificationKeyResolver.get(), + scopeClaimName, + subClaimName); + } else { + String scopeClaimName = cu.get(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME); + String subClaimName = cu.get(SASL_OAUTHBEARER_SUB_CLAIM_NAME); + delegate = new ClientJwtValidator(scopeClaimName, subClaimName); + } + + delegate.init(); + } + + @Override + public OAuthBearerToken validate(String accessToken) throws ValidateException { + if (delegate == null) + throw new IllegalStateException("JWT validator delegate is null; please call init() first"); + + return delegate.validate(accessToken); + } + + @Override + public void close() throws IOException { + Utils.closeQuietly(delegate, "JWT validator delegate"); + } + + JwtValidator delegate() { + return delegate; + } +} diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/FileTokenRetriever.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/FileJwtRetriever.java similarity index 83% rename from clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/FileTokenRetriever.java rename to clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/FileJwtRetriever.java index c145cf75969..f04b5600168 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/FileTokenRetriever.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/FileJwtRetriever.java @@ -23,19 +23,19 @@ import java.io.IOException; import java.nio.file.Path; /** - * FileTokenRetriever is an {@link AccessTokenRetriever} that will load the contents, - * interpreting them as a JWT access key in the serialized form. + * FileJwtRetriever is an {@link JwtRetriever} that will load the contents + * of a file, interpreting them as a JWT access key in the serialized form. * - * @see AccessTokenRetriever + * @see JwtRetriever */ -public class FileTokenRetriever implements AccessTokenRetriever { +public class FileJwtRetriever implements JwtRetriever { private final Path accessTokenFile; private String accessToken; - public FileTokenRetriever(Path accessTokenFile) { + public FileJwtRetriever(Path accessTokenFile) { this.accessTokenFile = accessTokenFile; } diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/HttpAccessTokenRetriever.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/HttpJwtRetriever.java similarity index 94% rename from clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/HttpAccessTokenRetriever.java rename to clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/HttpJwtRetriever.java index fdc5707278a..35d25564bc0 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/HttpAccessTokenRetriever.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/HttpJwtRetriever.java @@ -49,22 +49,14 @@ import javax.net.ssl.HttpsURLConnection; import javax.net.ssl.SSLSocketFactory; /** - * HttpAccessTokenRetriever is an {@link AccessTokenRetriever} that will - * communicate with an OAuth/OIDC provider directly via HTTP to post client credentials + * HttpJwtRetriever is a {@link JwtRetriever} that will communicate with an OAuth/OIDC + * provider directly via HTTP to post client credentials * ({@link OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG}/{@link OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG}) - * to a publicized token endpoint URL - * ({@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL}). - * - * @see AccessTokenRetriever - * @see OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG - * @see OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG - * @see OAuthBearerLoginCallbackHandler#SCOPE_CONFIG - * @see SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL + * to a publicized token endpoint URL ({@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL}). */ +public class HttpJwtRetriever implements JwtRetriever { -public class HttpAccessTokenRetriever implements AccessTokenRetriever { - - private static final Logger log = LoggerFactory.getLogger(HttpAccessTokenRetriever.class); + private static final Logger log = LoggerFactory.getLogger(HttpJwtRetriever.class); private static final Set UNRETRYABLE_HTTP_CODES; @@ -117,16 +109,16 @@ public class HttpAccessTokenRetriever implements AccessTokenRetriever { private final boolean urlencodeHeader; - public HttpAccessTokenRetriever(String clientId, - String clientSecret, - String scope, - SSLSocketFactory sslSocketFactory, - String tokenEndpointUrl, - long loginRetryBackoffMs, - long loginRetryBackoffMaxMs, - Integer loginConnectTimeoutMs, - Integer loginReadTimeoutMs, - boolean urlencodeHeader) { + public HttpJwtRetriever(String clientId, + String clientSecret, + String scope, + SSLSocketFactory sslSocketFactory, + String tokenEndpointUrl, + long loginRetryBackoffMs, + long loginRetryBackoffMaxMs, + Integer loginConnectTimeoutMs, + Integer loginReadTimeoutMs, + boolean urlencodeHeader) { this.clientId = Objects.requireNonNull(clientId); this.clientSecret = Objects.requireNonNull(clientSecret); this.scope = scope; diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/Initable.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/Initable.java index 0a38f2b5094..eff1b543886 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/Initable.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/Initable.java @@ -22,8 +22,8 @@ import java.io.IOException; public interface Initable { /** - * Lifecycle method to perform any one-time initialization of the retriever. This must - * be performed by the caller to ensure the correct state before methods are invoked. + * Lifecycle method to perform any one-time initialization of a given resource. This must + * be invoked by the caller to ensure the correct state before methods are invoked. * * @throws IOException Thrown on errors related to IO during initialization */ @@ -31,5 +31,4 @@ public interface Initable { default void init() throws IOException { // This method left intentionally blank. } - } diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenRetriever.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/JwtRetriever.java similarity index 88% rename from clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenRetriever.java rename to clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/JwtRetriever.java index 080ea4515b4..b8991250df0 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenRetriever.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/JwtRetriever.java @@ -21,20 +21,20 @@ import java.io.Closeable; import java.io.IOException; /** - * An AccessTokenRetriever is the internal API by which the login module will + * A JwtRetriever is the internal API by which the login module will * retrieve an access token for use in authorization by the broker. The implementation may * involve authentication to a remote system, or it can be as simple as loading the contents * of a file or configuration setting. * * Retrieval is a separate concern from validation, so it isn't necessary for - * the AccessTokenRetriever implementation to validate the integrity of the JWT + * the JwtRetriever implementation to validate the integrity of the JWT * access token. * - * @see HttpAccessTokenRetriever - * @see FileTokenRetriever + * @see HttpJwtRetriever + * @see FileJwtRetriever */ -public interface AccessTokenRetriever extends Initable, Closeable { +public interface JwtRetriever extends Initable, Closeable { /** * Retrieves a JWT access token in its serialized three-part form. The implementation diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenValidator.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/JwtValidator.java similarity index 80% rename from clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenValidator.java rename to clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/JwtValidator.java index 0b107a09bc0..82ba10652a1 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenValidator.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/JwtValidator.java @@ -19,8 +19,11 @@ package org.apache.kafka.common.security.oauthbearer.internals.secured; import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; +import java.io.Closeable; +import java.io.IOException; + /** - * An instance of AccessTokenValidator acts as a function object that, given an access + * An instance of JwtValidator acts as a function object that, given an access * token in base-64 encoded JWT format, can parse the data, perform validation, and construct an * {@link OAuthBearerToken} for use by the caller. * @@ -40,13 +43,12 @@ import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; *

  • RFC 6750, Section 2.1
  • * * - * @see LoginAccessTokenValidator A basic AccessTokenValidator used by client-side login - * authentication - * @see ValidatorAccessTokenValidator A more robust AccessTokenValidator that is used on the broker - * to validate the token's contents and verify the signature + * @see ClientJwtValidator A basic JwtValidator used by client-side login authentication + * @see BrokerJwtValidator A more robust JwtValidator that is used on the broker to validate the token's + * contents and verify the signature */ -public interface AccessTokenValidator { +public interface JwtValidator extends Initable, Closeable { /** * Accepts an OAuth JWT access token in base-64 encoded format, validates, and returns an @@ -61,4 +63,10 @@ public interface AccessTokenValidator { OAuthBearerToken validate(String accessToken) throws ValidateException; + /** + * Closes any resources that were initialized by {@link #init()}. + */ + default void close() throws IOException { + // Do nothing... + } } diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwks.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwks.java index 62261fed58d..4d75ff847ea 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwks.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/RefreshingHttpsJwks.java @@ -49,12 +49,12 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; * This instance is created and provided to the * {@link org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver} that is used when using * an HTTP-/HTTPS-based {@link org.jose4j.keys.resolvers.VerificationKeyResolver}, which is then - * provided to the {@link ValidatorAccessTokenValidator} to use in validating the signature of + * provided to the {@link BrokerJwtValidator} to use in validating the signature of * a JWT. * * @see org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver * @see org.jose4j.keys.resolvers.VerificationKeyResolver - * @see ValidatorAccessTokenValidator + * @see BrokerJwtValidator */ public final class RefreshingHttpsJwks implements Initable, Closeable { diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ValidateException.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ValidateException.java index 430b9007830..8c107abc831 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ValidateException.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ValidateException.java @@ -27,7 +27,7 @@ import javax.security.auth.callback.Callback; * processing of a {@link javax.security.auth.callback.CallbackHandler#handle(Callback[])}. * This error, however, is not thrown from that method directly. * - * @see AccessTokenValidator#validate(String) + * @see JwtValidator#validate(String) */ public class ValidateException extends KafkaException { diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/VerificationKeyResolverFactory.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/VerificationKeyResolverFactory.java index 0422045fc02..c9ad41d5a97 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/VerificationKeyResolverFactory.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/internals/secured/VerificationKeyResolverFactory.java @@ -37,7 +37,7 @@ import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_JWKS_E public class VerificationKeyResolverFactory { /** - * Create an {@link AccessTokenRetriever} from the given + * Create a {@link JwtRetriever} from the given * {@link org.apache.kafka.common.config.SaslConfigs}. * * Note: the returned CloseableVerificationKeyResolver is not diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginCallbackHandlerTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginCallbackHandlerTest.java index 5b1b2976662..290c58d6553 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginCallbackHandlerTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerLoginCallbackHandlerTest.java @@ -21,13 +21,12 @@ import org.apache.kafka.common.config.ConfigException; import org.apache.kafka.common.security.auth.SaslExtensionsCallback; import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse; import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory; -import org.apache.kafka.common.security.oauthbearer.internals.secured.FileTokenRetriever; -import org.apache.kafka.common.security.oauthbearer.internals.secured.HttpAccessTokenRetriever; +import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtRetriever; +import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator; +import org.apache.kafka.common.security.oauthbearer.internals.secured.FileJwtRetriever; +import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtRetriever; +import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator; import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest; -import org.apache.kafka.common.utils.Utils; import org.jose4j.jws.AlgorithmIdentifiers; import org.junit.jupiter.api.AfterEach; @@ -35,9 +34,7 @@ import org.junit.jupiter.api.Test; import java.io.File; import java.io.IOException; -import java.util.Base64; import java.util.Calendar; -import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.TimeZone; @@ -50,7 +47,6 @@ import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.ALL import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_ID_CONFIG; import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG; import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -58,6 +54,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest { + @AfterEach public void tearDown() throws Exception { System.clearProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG); @@ -70,9 +67,10 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest { .jwk(createRsaJwk()) .alg(AlgorithmIdentifiers.RSA_USING_SHA256); String accessToken = builder.build(); - AccessTokenRetriever accessTokenRetriever = () -> accessToken; - - OAuthBearerLoginCallbackHandler handler = createHandler(accessTokenRetriever, configs); + JwtRetriever jwtRetriever = () -> accessToken; + JwtValidator jwtValidator = createJwtValidator(configs); + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + handler.init(Map.of(), jwtRetriever, jwtValidator); try { OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); @@ -91,7 +89,6 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest { @Test public void testHandleSaslExtensionsCallback() throws Exception { - OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com"); System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com"); Map jaasConfig = new HashMap<>(); @@ -100,7 +97,11 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest { jaasConfig.put("extension_foo", "1"); jaasConfig.put("extension_bar", 2); jaasConfig.put("EXTENSION_baz", "3"); - configureHandler(handler, configs, jaasConfig); + + JwtRetriever jwtRetriever = createJwtRetriever(configs, jaasConfig); + JwtValidator jwtValidator = createJwtValidator(configs); + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + handler.init(jaasConfig, jwtRetriever, jwtValidator); try { SaslExtensionsCallback callback = new SaslExtensionsCallback(); @@ -121,14 +122,17 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest { public void testHandleSaslExtensionsCallbackWithInvalidExtension() { String illegalKey = "extension_" + OAuthBearerClientInitialResponse.AUTH_KEY; - OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com"); System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com"); Map jaasConfig = new HashMap<>(); jaasConfig.put(CLIENT_ID_CONFIG, "an ID"); jaasConfig.put(CLIENT_SECRET_CONFIG, "a secret"); jaasConfig.put(illegalKey, "this key isn't allowed per OAuthBearerClientInitialResponse.validateExtensions"); - configureHandler(handler, configs, jaasConfig); + + JwtRetriever jwtRetriever = createJwtRetriever(configs, jaasConfig); + JwtValidator jwtValidator = createJwtValidator(configs); + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + handler.init(jaasConfig, jwtRetriever, jwtValidator); try { SaslExtensionsCallback callback = new SaslExtensionsCallback(); @@ -143,10 +147,10 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest { @Test public void testInvalidCallbackGeneratesUnsupportedCallbackException() { Map configs = getSaslConfigs(); + JwtRetriever jwtRetriever = () -> "test"; + JwtValidator jwtValidator = createJwtValidator(configs); OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); - AccessTokenRetriever accessTokenRetriever = () -> "foo"; - AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs); - handler.init(accessTokenRetriever, accessTokenValidator); + handler.init(Map.of(), jwtRetriever, jwtValidator); try { Callback unsupportedCallback = new Callback() { }; @@ -166,11 +170,13 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest { @Test public void testMissingAccessToken() { - AccessTokenRetriever accessTokenRetriever = () -> { + Map configs = getSaslConfigs(); + JwtRetriever jwtRetriever = () -> { throw new IOException("The token endpoint response access_token value must be non-null"); }; - Map configs = getSaslConfigs(); - OAuthBearerLoginCallbackHandler handler = createHandler(accessTokenRetriever, configs); + JwtValidator jwtValidator = createJwtValidator(configs); + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + handler.init(Map.of(), jwtRetriever, jwtValidator); try { OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); @@ -196,7 +202,11 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest { File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", withNewline); Map configs = getSaslConfigs(); - OAuthBearerLoginCallbackHandler handler = createHandler(new FileTokenRetriever(accessTokenFile.toPath()), configs); + JwtRetriever jwtRetriever = new FileJwtRetriever(accessTokenFile.toPath()); + JwtValidator jwtValidator = createJwtValidator(configs); + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + handler.init(Map.of(), jwtRetriever, jwtValidator); + OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); try { handler.handle(new Callback[]{callback}); @@ -211,39 +221,15 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest { @Test public void testNotConfigured() { OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); - assertThrowsWithMessage(IllegalStateException.class, () -> handler.handle(new Callback[] {}), "first call the configure or init method"); - } - - @Test - public void testConfigureWithAccessTokenFile() throws Exception { - String expected = "{}"; - - File tmpDir = createTempDir("access-token"); - File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", expected); - System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString()); - - OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); - Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); - Map jaasConfigs = Collections.emptyMap(); - configureHandler(handler, configs, jaasConfigs); - assertInstanceOf(FileTokenRetriever.class, handler.getAccessTokenRetriever()); - } - - @Test - public void testConfigureWithAccessClientCredentials() { - OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); - Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com"); - System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com"); - Map jaasConfigs = new HashMap<>(); - jaasConfigs.put(CLIENT_ID_CONFIG, "an ID"); - jaasConfigs.put(CLIENT_SECRET_CONFIG, "a secret"); - configureHandler(handler, configs, jaasConfigs); - assertInstanceOf(HttpAccessTokenRetriever.class, handler.getAccessTokenRetriever()); + assertThrowsWithMessage(IllegalStateException.class, () -> handler.handle(new Callback[] {}), "first call the configure method"); } private void testInvalidAccessToken(String accessToken, String expectedMessageSubstring) throws Exception { Map configs = getSaslConfigs(); - OAuthBearerLoginCallbackHandler handler = createHandler(() -> accessToken, configs); + JwtRetriever jwtRetriever = () -> accessToken; + JwtValidator jwtValidator = createJwtValidator(configs); + OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); + handler.init(Map.of(), jwtRetriever, jwtValidator); try { OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); @@ -260,19 +246,15 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest { } } - private String createAccessKey(String header, String payload, String signature) { - Base64.Encoder enc = Base64.getEncoder(); - header = enc.encodeToString(Utils.utf8(header)); - payload = enc.encodeToString(Utils.utf8(payload)); - signature = enc.encodeToString(Utils.utf8(signature)); - return String.format("%s.%s.%s", header, payload, signature); + private static DefaultJwtRetriever createJwtRetriever(Map configs) { + return createJwtRetriever(configs, Map.of()); } - private OAuthBearerLoginCallbackHandler createHandler(AccessTokenRetriever accessTokenRetriever, Map configs) { - OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); - AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs); - handler.init(accessTokenRetriever, accessTokenValidator); - return handler; + private static DefaultJwtRetriever createJwtRetriever(Map configs, Map jaasConfigs) { + return new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfigs); } + private static DefaultJwtValidator createJwtValidator(Map configs) { + return new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM); + } } diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackHandlerTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackHandlerTest.java index d682a05ec11..0f1315b4281 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackHandlerTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/OAuthBearerValidatorCallbackHandlerTest.java @@ -17,27 +17,30 @@ package org.apache.kafka.common.security.oauthbearer; +import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory; import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver; +import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator; +import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator; import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest; -import org.apache.kafka.common.utils.Utils; +import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException; import org.jose4j.jws.AlgorithmIdentifiers; import org.junit.jupiter.api.Test; +import java.io.IOException; import java.util.Arrays; -import java.util.Base64; import java.util.List; import java.util.Map; import javax.security.auth.callback.Callback; import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest { @@ -53,7 +56,10 @@ public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest { String accessToken = builder.build(); Map configs = getSaslConfigs(SASL_OAUTHBEARER_EXPECTED_AUDIENCE, allAudiences); - OAuthBearerValidatorCallbackHandler handler = createHandler(configs, builder); + CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder); + JwtValidator jwtValidator = createJwtValidator(configs, verificationKeyResolver); + OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler(); + handler.init(verificationKeyResolver, jwtValidator); try { OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken); @@ -81,9 +87,68 @@ public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest { assertInvalidAccessTokenFails(createAccessKey("{}", "{}", "{}"), substring); } + @Test + public void testHandlerInitThrowsException() throws IOException { + IOException initError = new IOException("init() error"); + + AccessTokenBuilder builder = new AccessTokenBuilder() + .alg(AlgorithmIdentifiers.RSA_USING_SHA256); + CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder); + JwtValidator jwtValidator = new JwtValidator() { + @Override + public void init() throws IOException { + throw initError; + } + + @Override + public OAuthBearerToken validate(String accessToken) throws ValidateException { + return null; + } + }; + + OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler(); + + // An error initializing the JwtValidator should cause OAuthBearerValidatorCallbackHandler.init() to fail. + KafkaException root = assertThrows( + KafkaException.class, + () -> handler.init(verificationKeyResolver, jwtValidator) + ); + assertNotNull(root.getCause()); + assertEquals(initError, root.getCause()); + } + + @Test + public void testHandlerCloseDoesNotThrowException() throws IOException { + AccessTokenBuilder builder = new AccessTokenBuilder() + .alg(AlgorithmIdentifiers.RSA_USING_SHA256); + CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder); + JwtValidator jwtValidator = new JwtValidator() { + @Override + public void close() throws IOException { + throw new IOException("close() error"); + } + + @Override + public OAuthBearerToken validate(String accessToken) throws ValidateException { + return null; + } + }; + + OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler(); + handler.init(verificationKeyResolver, jwtValidator); + + // An error closings the JwtValidator should *not* cause OAuthBearerValidatorCallbackHandler.close() to fail. + assertDoesNotThrow(handler::close); + } + private void assertInvalidAccessTokenFails(String accessToken, String expectedMessageSubstring) throws Exception { + AccessTokenBuilder builder = new AccessTokenBuilder() + .alg(AlgorithmIdentifiers.RSA_USING_SHA256); Map configs = getSaslConfigs(); - OAuthBearerValidatorCallbackHandler handler = createHandler(configs, new AccessTokenBuilder()); + CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder); + JwtValidator jwtValidator = createJwtValidator(configs, verificationKeyResolver); + OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler(); + handler.init(verificationKeyResolver, jwtValidator); try { OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken); @@ -98,22 +163,11 @@ public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest { } } - private OAuthBearerValidatorCallbackHandler createHandler(Map options, - AccessTokenBuilder builder) { - OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler(); - CloseableVerificationKeyResolver verificationKeyResolver = (jws, nestingContext) -> - builder.jwk().getPublicKey(); - AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(options, verificationKeyResolver); - handler.init(verificationKeyResolver, accessTokenValidator); - return handler; + private JwtValidator createJwtValidator(Map configs, CloseableVerificationKeyResolver verificationKeyResolver) { + return new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, verificationKeyResolver); } - private String createAccessKey(String header, String payload, String signature) { - Base64.Encoder enc = Base64.getEncoder(); - header = enc.encodeToString(Utils.utf8(header)); - payload = enc.encodeToString(Utils.utf8(payload)); - signature = enc.encodeToString(Utils.utf8(signature)); - return String.format("%s.%s.%s", header, payload, signature); + private CloseableVerificationKeyResolver createVerificationKeyResolver(AccessTokenBuilder builder) { + return (jws, nestingContext) -> builder.jwk().getPublicKey(); } - } diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenValidatorFactoryTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenValidatorFactoryTest.java deleted file mode 100644 index 2fd02e3f9a8..00000000000 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenValidatorFactoryTest.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.kafka.common.security.oauthbearer.internals.secured; - -import org.apache.kafka.common.KafkaException; -import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler; - -import org.junit.jupiter.api.Test; - -import java.io.IOException; -import java.util.Map; - -public class AccessTokenValidatorFactoryTest extends OAuthBearerTest { - - @Test - public void testConfigureThrowsExceptionOnAccessTokenValidatorInit() { - OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); - AccessTokenRetriever accessTokenRetriever = new AccessTokenRetriever() { - @Override - public void init() throws IOException { - throw new IOException("My init had an error!"); - } - @Override - public String retrieve() { - return "dummy"; - } - }; - - Map configs = getSaslConfigs(); - AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs); - - assertThrowsWithMessage( - KafkaException.class, () -> handler.init(accessTokenRetriever, accessTokenValidator), "encountered an error when initializing"); - } - - @Test - public void testConfigureThrowsExceptionOnAccessTokenValidatorClose() { - OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); - AccessTokenRetriever accessTokenRetriever = new AccessTokenRetriever() { - @Override - public void close() throws IOException { - throw new IOException("My close had an error!"); - } - @Override - public String retrieve() { - return "dummy"; - } - }; - - Map configs = getSaslConfigs(); - AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs); - handler.init(accessTokenRetriever, accessTokenValidator); - - // Basically asserting this doesn't throw an exception :( - handler.close(); - } - -} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ValidatorAccessTokenValidatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/BrokerJwtValidatorTest.java similarity index 89% rename from clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ValidatorAccessTokenValidatorTest.java rename to clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/BrokerJwtValidatorTest.java index 4db20e9ee10..3b06bf07dec 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ValidatorAccessTokenValidatorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/BrokerJwtValidatorTest.java @@ -28,11 +28,11 @@ import java.util.Collections; import static org.junit.jupiter.api.Assertions.assertEquals; -public class ValidatorAccessTokenValidatorTest extends AccessTokenValidatorTest { +public class BrokerJwtValidatorTest extends JwtValidatorTest { @Override - protected AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder builder) { - return new ValidatorAccessTokenValidator(30, + protected JwtValidator createJwtValidator(AccessTokenBuilder builder) { + return new BrokerJwtValidator(30, Collections.emptySet(), null, (jws, nestingContext) -> builder.jwk().getKey(), @@ -72,7 +72,7 @@ public class ValidatorAccessTokenValidatorTest extends AccessTokenValidatorTest .addCustomClaim(subClaimName, subject) .subjectClaimName(subClaimName) .subject(null); - AccessTokenValidator validator = createAccessTokenValidator(tokenBuilder); + JwtValidator validator = createJwtValidator(tokenBuilder); // Validation should succeed (e.g. signature verification) even if sub claim is missing OAuthBearerToken token = validator.validate(tokenBuilder.build()); @@ -82,7 +82,7 @@ public class ValidatorAccessTokenValidatorTest extends AccessTokenValidatorTest private void testEncryptionAlgorithm(PublicJsonWebKey jwk, String alg) throws Exception { AccessTokenBuilder builder = new AccessTokenBuilder().jwk(jwk).alg(alg); - AccessTokenValidator validator = createAccessTokenValidator(builder); + JwtValidator validator = createJwtValidator(builder); String accessToken = builder.build(); OAuthBearerToken token = validator.validate(accessToken); diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/LoginAccessTokenValidatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ClientJwtValidatorTest.java similarity index 76% rename from clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/LoginAccessTokenValidatorTest.java rename to clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ClientJwtValidatorTest.java index fc2e3d2a2e8..280aecd82c3 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/LoginAccessTokenValidatorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/ClientJwtValidatorTest.java @@ -17,11 +17,11 @@ package org.apache.kafka.common.security.oauthbearer.internals.secured; -public class LoginAccessTokenValidatorTest extends AccessTokenValidatorTest { +public class ClientJwtValidatorTest extends JwtValidatorTest { @Override - protected AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder builder) { - return new LoginAccessTokenValidator(builder.scopeClaimName(), builder.subjectClaimName()); + protected JwtValidator createJwtValidator(AccessTokenBuilder builder) { + return new ClientJwtValidator(builder.scopeClaimName(), builder.subjectClaimName()); } } diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenRetrieverFactoryTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/DefaultJwtRetrieverTest.java similarity index 56% rename from clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenRetrieverFactoryTest.java rename to clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/DefaultJwtRetrieverTest.java index 3e85f7b0ce4..83fd57713b0 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenRetrieverFactoryTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/DefaultJwtRetrieverTest.java @@ -18,6 +18,7 @@ package org.apache.kafka.common.security.oauthbearer.internals.secured; import org.apache.kafka.common.config.ConfigException; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; @@ -26,7 +27,9 @@ import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; import java.io.File; +import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.stream.Stream; @@ -34,9 +37,13 @@ import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARE import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_HEADER_URLENCODE; import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL; import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG; +import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_ID_CONFIG; +import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG; +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; -public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest { +public class DefaultJwtRetrieverTest extends OAuthBearerTest { @AfterEach public void tearDown() throws Exception { @@ -44,7 +51,7 @@ public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest { } @Test - public void testConfigureRefreshingFileAccessTokenRetriever() throws Exception { + public void testConfigureRefreshingFileJwtRetriever() throws Exception { String expected = "{}"; File tmpDir = createTempDir("access-token"); @@ -54,31 +61,37 @@ public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest { Map configs = Collections.singletonMap(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); Map jaasConfig = Collections.emptyMap(); - try (AccessTokenRetriever accessTokenRetriever = AccessTokenRetrieverFactory.create(configs, jaasConfig)) { - accessTokenRetriever.init(); - assertEquals(expected, accessTokenRetriever.retrieve()); + try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfig)) { + jwtRetriever.init(); + assertEquals(expected, jwtRetriever.retrieve()); } } @Test - public void testConfigureRefreshingFileAccessTokenRetrieverWithInvalidDirectory() { + public void testConfigureRefreshingFileJwtRetrieverWithInvalidDirectory() throws IOException { // Should fail because the parent path doesn't exist. String file = new File("/tmp/this-directory-does-not-exist/foo.json").toURI().toString(); System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, file); Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, file); Map jaasConfig = Collections.emptyMap(); - assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, jaasConfig), "that doesn't exist"); + + try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfig)) { + assertThrowsWithMessage(ConfigException.class, jwtRetriever::init, "that doesn't exist"); + } } @Test - public void testConfigureRefreshingFileAccessTokenRetrieverWithInvalidFile() throws Exception { + public void testConfigureRefreshingFileJwtRetrieverWithInvalidFile() throws Exception { // Should fail because while the parent path exists, the file itself doesn't. File tmpDir = createTempDir("this-directory-does-exist"); File accessTokenFile = new File(tmpDir, "this-file-does-not-exist.json"); System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString()); Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); Map jaasConfig = Collections.emptyMap(); - assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, jaasConfig), "that doesn't exist"); + + try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfig)) { + assertThrowsWithMessage(ConfigException.class, jwtRetriever::init, "that doesn't exist"); + } } @Test @@ -87,15 +100,53 @@ public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest { File tmpDir = createTempDir("not_allowed"); File accessTokenFile = new File(tmpDir, "not_allowed.json"); Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); - assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, Collections.emptyMap()), - ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG); + + try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, Collections.emptyMap())) { + assertThrowsWithMessage(ConfigException.class, jwtRetriever::init, ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG); + } + } + + @Test + public void testConfigureWithAccessTokenFile() throws Exception { + String expected = "{}"; + + File tmpDir = createTempDir("access-token"); + File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", expected); + System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString()); + + Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); + + DefaultJwtRetriever jwtRetriever = new DefaultJwtRetriever( + configs, + OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + Map.of() + ); + assertDoesNotThrow(jwtRetriever::init); + assertInstanceOf(FileJwtRetriever.class, jwtRetriever.delegate()); + } + + @Test + public void testConfigureWithAccessClientCredentials() { + Map configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com"); + System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com"); + Map jaasConfigs = new HashMap<>(); + jaasConfigs.put(CLIENT_ID_CONFIG, "an ID"); + jaasConfigs.put(CLIENT_SECRET_CONFIG, "a secret"); + + DefaultJwtRetriever jwtRetriever = new DefaultJwtRetriever( + configs, + OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + jaasConfigs + ); + assertDoesNotThrow(jwtRetriever::init); + assertInstanceOf(HttpJwtRetriever.class, jwtRetriever.delegate()); } @ParameterizedTest @MethodSource("urlencodeHeaderSupplier") public void testUrlencodeHeader(Map configs, boolean expectedValue) { ConfigurationUtils cu = new ConfigurationUtils(configs); - boolean actualValue = AccessTokenRetrieverFactory.validateUrlencodeHeader(cu); + boolean actualValue = DefaultJwtRetriever.validateUrlencodeHeader(cu); assertEquals(expectedValue, actualValue); } diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/DefaultJwtValidatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/DefaultJwtValidatorTest.java new file mode 100644 index 00000000000..9d136b72b14 --- /dev/null +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/DefaultJwtValidatorTest.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.kafka.common.security.oauthbearer.internals.secured; + +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; + +import org.jose4j.jws.AlgorithmIdentifiers; +import org.junit.jupiter.api.Test; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; + +public class DefaultJwtValidatorTest extends OAuthBearerTest { + + @Test + public void testConfigureWithVerificationKeyResolver() { + AccessTokenBuilder builder = new AccessTokenBuilder() + .alg(AlgorithmIdentifiers.RSA_USING_SHA256); + CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder); + Map configs = getSaslConfigs(); + DefaultJwtValidator jwtValidator = new DefaultJwtValidator( + configs, + OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, + verificationKeyResolver + ); + assertDoesNotThrow(jwtValidator::init); + assertInstanceOf(BrokerJwtValidator.class, jwtValidator.delegate()); + } + + @Test + public void testConfigureWithoutVerificationKeyResolver() { + Map configs = getSaslConfigs(); + DefaultJwtValidator jwtValidator = new DefaultJwtValidator( + configs, + OAuthBearerLoginModule.OAUTHBEARER_MECHANISM + ); + assertDoesNotThrow(jwtValidator::init); + assertInstanceOf(ClientJwtValidator.class, jwtValidator.delegate()); + } + + private CloseableVerificationKeyResolver createVerificationKeyResolver(AccessTokenBuilder builder) { + return (jws, nestingContext) -> builder.jwk().getPublicKey(); + } +} diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/HttpAccessTokenRetrieverTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/HttpJwtRetrieverTest.java similarity index 73% rename from clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/HttpAccessTokenRetrieverTest.java rename to clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/HttpJwtRetrieverTest.java index 8b1c5a37065..0bd903300ff 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/HttpAccessTokenRetrieverTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/HttpJwtRetrieverTest.java @@ -39,20 +39,20 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; -public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { +public class HttpJwtRetrieverTest extends OAuthBearerTest { @Test public void test() throws IOException { String expectedResponse = "Hiya, buddy"; HttpURLConnection mockedCon = createHttpURLConnection(expectedResponse); - String response = HttpAccessTokenRetriever.post(mockedCon, null, null, null, null); + String response = HttpJwtRetriever.post(mockedCon, null, null, null, null); assertEquals(expectedResponse, response); } @Test public void testEmptyResponse() throws IOException { HttpURLConnection mockedCon = createHttpURLConnection(""); - assertThrows(IOException.class, () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + assertThrows(IOException.class, () -> HttpJwtRetriever.post(mockedCon, null, null, null, null)); } @Test @@ -60,7 +60,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { HttpURLConnection mockedCon = createHttpURLConnection("dummy"); when(mockedCon.getInputStream()).thenThrow(new IOException("Can't read")); - assertThrows(IOException.class, () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + assertThrows(IOException.class, () -> HttpJwtRetriever.post(mockedCon, null, null, null, null)); } @Test @@ -72,7 +72,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { .getBytes(StandardCharsets.UTF_8))); when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST); UnretryableException ioe = assertThrows(UnretryableException.class, - () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + () -> HttpJwtRetriever.post(mockedCon, null, null, null, null)); assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}")); } @@ -85,7 +85,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { .getBytes(StandardCharsets.UTF_8))); when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR); IOException ioe = assertThrows(IOException.class, - () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + () -> HttpJwtRetriever.post(mockedCon, null, null, null, null)); assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}")); // error response body has different keys @@ -93,7 +93,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { "{\"errorCode\":\"some_arg\", \"errorSummary\":\"some problem with arg\"}" .getBytes(StandardCharsets.UTF_8))); ioe = assertThrows(IOException.class, - () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + () -> HttpJwtRetriever.post(mockedCon, null, null, null, null)); assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}")); // error response is valid json but unknown keys @@ -101,7 +101,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { "{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}" .getBytes(StandardCharsets.UTF_8))); ioe = assertThrows(IOException.class, - () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + () -> HttpJwtRetriever.post(mockedCon, null, null, null, null)); assertTrue(ioe.getMessage().contains("{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}")); } @@ -113,7 +113,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { "non json error output".getBytes(StandardCharsets.UTF_8))); when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR); IOException ioe = assertThrows(IOException.class, - () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + () -> HttpJwtRetriever.post(mockedCon, null, null, null, null)); assertTrue(ioe.getMessage().contains("{non json error output}")); } @@ -124,7 +124,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { r.nextBytes(expected); InputStream in = new ByteArrayInputStream(expected); ByteArrayOutputStream out = new ByteArrayOutputStream(); - HttpAccessTokenRetriever.copy(in, out); + HttpJwtRetriever.copy(in, out); assertArrayEquals(expected, out.toByteArray()); } @@ -133,7 +133,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { InputStream mockedIn = mock(InputStream.class); OutputStream out = new ByteArrayOutputStream(); when(mockedIn.read(any(byte[].class))).thenThrow(new IOException()); - assertThrows(IOException.class, () -> HttpAccessTokenRetriever.copy(mockedIn, out)); + assertThrows(IOException.class, () -> HttpJwtRetriever.copy(mockedIn, out)); } @Test @@ -143,7 +143,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { ObjectNode node = mapper.createObjectNode(); node.put("access_token", expected); - String actual = HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node)); + String actual = HttpJwtRetriever.parseAccessToken(mapper.writeValueAsString(node)); assertEquals(expected, actual); } @@ -153,7 +153,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { ObjectNode node = mapper.createObjectNode(); node.put("access_token", ""); - assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node))); + assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.parseAccessToken(mapper.writeValueAsString(node))); } @Test @@ -162,12 +162,12 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { ObjectNode node = mapper.createObjectNode(); node.put("sub", "jdoe"); - assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node))); + assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.parseAccessToken(mapper.writeValueAsString(node))); } @Test public void testParseAccessTokenInvalidJson() { - assertThrows(IOException.class, () -> HttpAccessTokenRetriever.parseAccessToken("not valid JSON")); + assertThrows(IOException.class, () -> HttpJwtRetriever.parseAccessToken("not valid JSON")); } @Test @@ -184,27 +184,27 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { } private void assertAuthorizationHeader(String clientId, String clientSecret, boolean urlencode, String expected) { - String actual = HttpAccessTokenRetriever.formatAuthorizationHeader(clientId, clientSecret, urlencode); + String actual = HttpJwtRetriever.formatAuthorizationHeader(clientId, clientSecret, urlencode); assertEquals(expected, actual, String.format("Expected the HTTP Authorization header generated for client ID \"%s\" and client secret \"%s\" to match", clientId, clientSecret)); } @Test public void testFormatAuthorizationHeaderMissingValues() { - assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(null, "secret", false)); - assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", null, false)); - assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(null, null, false)); - assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("", "secret", false)); - assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", "", false)); - assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("", "", false)); - assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(" ", "secret", false)); - assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", " ", false)); - assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(" ", " ", false)); + assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(null, "secret", false)); + assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("id", null, false)); + assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(null, null, false)); + assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("", "secret", false)); + assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("id", "", false)); + assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("", "", false)); + assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(" ", "secret", false)); + assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("id", " ", false)); + assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(" ", " ", false)); } @Test public void testFormatRequestBody() { String expected = "grant_type=client_credentials&scope=scope"; - String actual = HttpAccessTokenRetriever.formatRequestBody("scope"); + String actual = HttpJwtRetriever.formatRequestBody("scope"); assertEquals(expected, actual); } @@ -214,24 +214,24 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { String exclamationMark = "%21"; String expected = String.format("grant_type=client_credentials&scope=earth+is+great%s", exclamationMark); - String actual = HttpAccessTokenRetriever.formatRequestBody("earth is great!"); + String actual = HttpJwtRetriever.formatRequestBody("earth is great!"); assertEquals(expected, actual); expected = String.format("grant_type=client_credentials&scope=what+on+earth%s%s%s%s%s", questionMark, exclamationMark, questionMark, exclamationMark, questionMark); - actual = HttpAccessTokenRetriever.formatRequestBody("what on earth?!?!?"); + actual = HttpJwtRetriever.formatRequestBody("what on earth?!?!?"); assertEquals(expected, actual); } @Test public void testFormatRequestBodyMissingValues() { String expected = "grant_type=client_credentials"; - String actual = HttpAccessTokenRetriever.formatRequestBody(null); + String actual = HttpJwtRetriever.formatRequestBody(null); assertEquals(expected, actual); - actual = HttpAccessTokenRetriever.formatRequestBody(""); + actual = HttpJwtRetriever.formatRequestBody(""); assertEquals(expected, actual); - actual = HttpAccessTokenRetriever.formatRequestBody(" "); + actual = HttpJwtRetriever.formatRequestBody(" "); assertEquals(expected, actual); } diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenValidatorTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/JwtValidatorTest.java similarity index 80% rename from clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenValidatorTest.java rename to clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/JwtValidatorTest.java index 0adaf34bbbe..bfbf29d0266 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/AccessTokenValidatorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/JwtValidatorTest.java @@ -26,42 +26,42 @@ import org.junit.jupiter.api.TestInstance.Lifecycle; import static org.junit.jupiter.api.Assertions.assertThrows; @TestInstance(Lifecycle.PER_CLASS) -public abstract class AccessTokenValidatorTest extends OAuthBearerTest { +public abstract class JwtValidatorTest extends OAuthBearerTest { - protected abstract AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder accessTokenBuilder) throws Exception; + protected abstract JwtValidator createJwtValidator(AccessTokenBuilder accessTokenBuilder) throws Exception; - protected AccessTokenValidator createAccessTokenValidator() throws Exception { + protected JwtValidator createJwtValidator() throws Exception { AccessTokenBuilder builder = new AccessTokenBuilder(); - return createAccessTokenValidator(builder); + return createJwtValidator(builder); } @Test public void testNull() throws Exception { - AccessTokenValidator validator = createAccessTokenValidator(); + JwtValidator validator = createJwtValidator(); assertThrowsWithMessage(ValidateException.class, () -> validator.validate(null), "Malformed JWT provided; expected three sections (header, payload, and signature)"); } @Test public void testEmptyString() throws Exception { - AccessTokenValidator validator = createAccessTokenValidator(); + JwtValidator validator = createJwtValidator(); assertThrowsWithMessage(ValidateException.class, () -> validator.validate(""), "Malformed JWT provided; expected three sections (header, payload, and signature)"); } @Test public void testWhitespace() throws Exception { - AccessTokenValidator validator = createAccessTokenValidator(); + JwtValidator validator = createJwtValidator(); assertThrowsWithMessage(ValidateException.class, () -> validator.validate(" "), "Malformed JWT provided; expected three sections (header, payload, and signature)"); } @Test public void testEmptySections() throws Exception { - AccessTokenValidator validator = createAccessTokenValidator(); + JwtValidator validator = createJwtValidator(); assertThrowsWithMessage(ValidateException.class, () -> validator.validate(".."), "Malformed JWT provided; expected three sections (header, payload, and signature)"); } @Test public void testMissingHeader() throws Exception { - AccessTokenValidator validator = createAccessTokenValidator(); + JwtValidator validator = createJwtValidator(); String header = ""; String payload = createBase64JsonJwtSection(node -> { }); String signature = ""; @@ -71,7 +71,7 @@ public abstract class AccessTokenValidatorTest extends OAuthBearerTest { @Test public void testMissingPayload() throws Exception { - AccessTokenValidator validator = createAccessTokenValidator(); + JwtValidator validator = createJwtValidator(); String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE)); String payload = ""; String signature = ""; @@ -81,7 +81,7 @@ public abstract class AccessTokenValidatorTest extends OAuthBearerTest { @Test public void testMissingSignature() throws Exception { - AccessTokenValidator validator = createAccessTokenValidator(); + JwtValidator validator = createJwtValidator(); String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE)); String payload = createBase64JsonJwtSection(node -> { }); String signature = ""; diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/OAuthBearerTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/OAuthBearerTest.java index 7f20b9464fa..8e82092f28d 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/OAuthBearerTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/internals/secured/OAuthBearerTest.java @@ -19,9 +19,6 @@ package org.apache.kafka.common.security.oauthbearer.internals.secured; import org.apache.kafka.common.config.AbstractConfig; import org.apache.kafka.common.config.ConfigDef; -import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; -import org.apache.kafka.common.security.authenticator.TestJaasConfig; -import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; import org.apache.kafka.common.utils.Utils; import com.fasterxml.jackson.databind.ObjectMapper; @@ -52,8 +49,6 @@ import java.util.Map; import java.util.concurrent.ExecutionException; import java.util.function.Consumer; -import javax.security.auth.login.AppConfigurationEntry; - import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; @@ -80,18 +75,6 @@ public abstract class OAuthBearerTest { expectedSubstring)); } - protected void configureHandler(AuthenticateCallbackHandler handler, - Map configs, - Map jaasConfig) { - TestJaasConfig config = new TestJaasConfig(); - config.createOrUpdateEntry("KafkaClient", OAuthBearerLoginModule.class.getName(), jaasConfig); - AppConfigurationEntry kafkaClient = config.getAppConfigurationEntry("KafkaClient")[0]; - - handler.configure(configs, - OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, - Collections.singletonList(kafkaClient)); - } - protected String createBase64JsonJwtSection(Consumer c) { String json = createJsonJwtSection(c); @@ -212,4 +195,11 @@ public abstract class OAuthBearerTest { return jwk; } + protected String createAccessKey(String header, String payload, String signature) { + Base64.Encoder enc = Base64.getEncoder(); + header = enc.encodeToString(Utils.utf8(header)); + payload = enc.encodeToString(Utils.utf8(payload)); + signature = enc.encodeToString(Utils.utf8(signature)); + return String.format("%s.%s.%s", header, payload, signature); + } } \ No newline at end of file diff --git a/tools/src/main/java/org/apache/kafka/tools/OAuthCompatibilityTool.java b/tools/src/main/java/org/apache/kafka/tools/OAuthCompatibilityTool.java index 485146aea7e..7852c3a07e0 100644 --- a/tools/src/main/java/org/apache/kafka/tools/OAuthCompatibilityTool.java +++ b/tools/src/main/java/org/apache/kafka/tools/OAuthCompatibilityTool.java @@ -24,11 +24,12 @@ import org.apache.kafka.common.config.ConfigException; import org.apache.kafka.common.config.SaslConfigs; import org.apache.kafka.common.config.SslConfigs; import org.apache.kafka.common.config.types.Password; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetrieverFactory; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator; -import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory; +import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule; import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver; +import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtRetriever; +import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator; +import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtRetriever; +import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator; import org.apache.kafka.common.security.oauthbearer.internals.secured.VerificationKeyResolverFactory; import org.apache.kafka.common.utils.Exit; @@ -139,16 +140,19 @@ public class OAuthCompatibilityTool { { // Client side... - try (AccessTokenRetriever atr = AccessTokenRetrieverFactory.create(configs, jaasConfigs)) { + try (JwtRetriever atr = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfigs)) { atr.init(); - AccessTokenValidator atv = AccessTokenValidatorFactory.create(configs); - System.out.println("PASSED 1/5: client configuration"); - accessToken = atr.retrieve(); - System.out.println("PASSED 2/5: client JWT retrieval"); + try (JwtValidator atv = new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM)) { + atv.init(); + System.out.println("PASSED 1/5: client configuration"); - atv.validate(accessToken); - System.out.println("PASSED 3/5: client JWT validation"); + accessToken = atr.retrieve(); + System.out.println("PASSED 2/5: client JWT retrieval"); + + atv.validate(accessToken); + System.out.println("PASSED 3/5: client JWT validation"); + } } } @@ -156,11 +160,14 @@ public class OAuthCompatibilityTool { // Broker side... try (CloseableVerificationKeyResolver vkr = VerificationKeyResolverFactory.create(configs, jaasConfigs)) { vkr.init(); - AccessTokenValidator atv = AccessTokenValidatorFactory.create(configs, vkr); - System.out.println("PASSED 4/5: broker configuration"); - atv.validate(accessToken); - System.out.println("PASSED 5/5: broker JWT validation"); + try (JwtValidator atv = new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, vkr)) { + atv.init(); + System.out.println("PASSED 4/5: broker configuration"); + + atv.validate(accessToken); + System.out.println("PASSED 5/5: broker JWT validation"); + } } }