KAFKA-18847: Refactor OAuth layer to improve reusability 1/N (#19622)
CI / build (push) Waiting to run Details

Rename `AccessTokenRetriever` and `AccessTokenValidator` to
`JwtRetriever` and `JwtValidator`, respectively. Also converting the
factory pattern classes `AccessTokenRetrieverFactory` and
`AccessTokenValidatorFactory` into delegate/wrapper classes
`DefaultJwtRetriever` and `DefaultJwtValidator`, respectively.

These are all internal changes, no configuration, user APIs, RPCs, etc.
were changed.

Reviewers: Manikumar Reddy <manikumar@confluent.io>, Ken Huang
 <s7133700@gmail.com>, Lianet Magrans <lmagrans@confluent.io>

---------

Co-authored-by: Ken Huang <s7133700@gmail.com>
This commit is contained in:
Kirk True 2025-05-13 09:35:20 -07:00 committed by GitHub
parent c16c240bd1
commit c60c83aaba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 598 additions and 476 deletions

View File

@ -24,12 +24,13 @@ import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.auth.SaslExtensions; import org.apache.kafka.common.security.auth.SaslExtensions;
import org.apache.kafka.common.security.auth.SaslExtensionsCallback; import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse; import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever; import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetrieverFactory; import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils; import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException; import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException;
import org.apache.kafka.common.utils.Utils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -179,55 +180,48 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand
private Map<String, Object> moduleOptions; private Map<String, Object> moduleOptions;
private AccessTokenRetriever accessTokenRetriever; private JwtRetriever jwtRetriever;
private AccessTokenValidator accessTokenValidator; private JwtValidator jwtValidator;
private boolean isInitialized = false;
@Override @Override
public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) { public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
moduleOptions = JaasOptionsUtils.getOptions(saslMechanism, jaasConfigEntries); Map<String, Object> moduleOptions = JaasOptionsUtils.getOptions(saslMechanism, jaasConfigEntries);
AccessTokenRetriever accessTokenRetriever = AccessTokenRetrieverFactory.create(configs, saslMechanism, moduleOptions); JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, saslMechanism, moduleOptions);
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs, saslMechanism); JwtValidator jwtValidator = new DefaultJwtValidator(configs, saslMechanism);
init(accessTokenRetriever, accessTokenValidator); init(moduleOptions, jwtRetriever, jwtValidator);
}
public void init(AccessTokenRetriever accessTokenRetriever, AccessTokenValidator accessTokenValidator) {
this.accessTokenRetriever = accessTokenRetriever;
this.accessTokenValidator = accessTokenValidator;
try {
this.accessTokenRetriever.init();
} catch (IOException e) {
throw new KafkaException("The OAuth login configuration encountered an error when initializing the AccessTokenRetriever", e);
}
isInitialized = true;
} }
/* /*
* Package-visible for testing. * Package-visible for testing.
*/ */
void init(Map<String, Object> moduleOptions, JwtRetriever jwtRetriever, JwtValidator jwtValidator) {
this.moduleOptions = moduleOptions;
this.jwtRetriever = jwtRetriever;
this.jwtValidator = jwtValidator;
AccessTokenRetriever getAccessTokenRetriever() { try {
return accessTokenRetriever; this.jwtRetriever.init();
} } catch (IOException e) {
throw new KafkaException("The OAuth login callback encountered an error when initializing the JwtRetriever", e);
}
@Override try {
public void close() { this.jwtValidator.init();
if (accessTokenRetriever != null) { } catch (IOException e) {
try { throw new KafkaException("The OAuth login callback encountered an error when initializing the JwtValidator", e);
this.accessTokenRetriever.close();
} catch (IOException e) {
log.warn("The OAuth login configuration encountered an error when closing the AccessTokenRetriever", e);
}
} }
} }
@Override
public void close() {
Utils.closeQuietly(jwtRetriever, "JWT retriever");
Utils.closeQuietly(jwtValidator, "JWT validator");
}
@Override @Override
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
checkInitialized(); checkConfigured();
for (Callback callback : callbacks) { for (Callback callback : callbacks) {
if (callback instanceof OAuthBearerTokenCallback) { if (callback instanceof OAuthBearerTokenCallback) {
@ -241,11 +235,11 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand
} }
private void handleTokenCallback(OAuthBearerTokenCallback callback) throws IOException { private void handleTokenCallback(OAuthBearerTokenCallback callback) throws IOException {
checkInitialized(); checkConfigured();
String accessToken = accessTokenRetriever.retrieve(); String accessToken = jwtRetriever.retrieve();
try { try {
OAuthBearerToken token = accessTokenValidator.validate(accessToken); OAuthBearerToken token = jwtValidator.validate(accessToken);
callback.token(token); callback.token(token);
} catch (ValidateException e) { } catch (ValidateException e) {
log.warn(e.getMessage(), e); log.warn(e.getMessage(), e);
@ -254,7 +248,7 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand
} }
private void handleExtensionsCallback(SaslExtensionsCallback callback) { private void handleExtensionsCallback(SaslExtensionsCallback callback) {
checkInitialized(); checkConfigured();
Map<String, String> extensions = new HashMap<>(); Map<String, String> extensions = new HashMap<>();
@ -286,9 +280,9 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand
callback.extensions(saslExtensions); callback.extensions(saslExtensions);
} }
private void checkInitialized() { private void checkConfigured() {
if (!isInitialized) if (moduleOptions == null || jwtRetriever == null || jwtValidator == null)
throw new IllegalStateException(String.format("To use %s, first call the configure or init method", getClass().getSimpleName())); throw new IllegalStateException(String.format("To use %s, first call the configure method", getClass().getSimpleName()));
} }
} }

View File

@ -19,13 +19,14 @@ package org.apache.kafka.common.security.oauthbearer;
import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver; import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils; import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.RefreshingHttpsJwksVerificationKeyResolver; import org.apache.kafka.common.security.oauthbearer.internals.secured.RefreshingHttpsJwksVerificationKeyResolver;
import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException; import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException;
import org.apache.kafka.common.security.oauthbearer.internals.secured.VerificationKeyResolverFactory; import org.apache.kafka.common.security.oauthbearer.internals.secured.VerificationKeyResolverFactory;
import org.apache.kafka.common.utils.Utils;
import org.jose4j.jws.JsonWebSignature; import org.jose4j.jws.JsonWebSignature;
import org.jose4j.jwx.JsonWebStructure; import org.jose4j.jwx.JsonWebStructure;
@ -119,9 +120,7 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback
private CloseableVerificationKeyResolver verificationKeyResolver; private CloseableVerificationKeyResolver verificationKeyResolver;
private AccessTokenValidator accessTokenValidator; private JwtValidator jwtValidator;
private boolean isInitialized = false;
@Override @Override
public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) { public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
@ -135,37 +134,39 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback
new RefCountingVerificationKeyResolver(VerificationKeyResolverFactory.create(configs, saslMechanism, moduleOptions))); new RefCountingVerificationKeyResolver(VerificationKeyResolverFactory.create(configs, saslMechanism, moduleOptions)));
} }
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs, saslMechanism, verificationKeyResolver); JwtValidator jwtValidator = new DefaultJwtValidator(configs, saslMechanism, verificationKeyResolver);
init(verificationKeyResolver, accessTokenValidator); init(verificationKeyResolver, jwtValidator);
} }
public void init(CloseableVerificationKeyResolver verificationKeyResolver, AccessTokenValidator accessTokenValidator) { /*
* Package-visible for testing.
*/
void init(CloseableVerificationKeyResolver verificationKeyResolver, JwtValidator jwtValidator) {
this.verificationKeyResolver = verificationKeyResolver; this.verificationKeyResolver = verificationKeyResolver;
this.accessTokenValidator = accessTokenValidator; this.jwtValidator = jwtValidator;
try { try {
verificationKeyResolver.init(); verificationKeyResolver.init();
} catch (Exception e) { } catch (Exception e) {
throw new KafkaException("The OAuth validator configuration encountered an error when initializing the VerificationKeyResolver", e); throw new KafkaException("The OAuth validator callback encountered an error when initializing the VerificationKeyResolver", e);
} }
isInitialized = true; try {
jwtValidator.init();
} catch (IOException e) {
throw new KafkaException("The OAuth validator callback encountered an error when initializing the JwtValidator", e);
}
} }
@Override @Override
public void close() { public void close() {
if (verificationKeyResolver != null) { Utils.closeQuietly(jwtValidator, "JWT validator");
try { Utils.closeQuietly(verificationKeyResolver, "JWT verification key resolver");
verificationKeyResolver.close();
} catch (Exception e) {
log.error(e.getMessage(), e);
}
}
} }
@Override @Override
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
checkInitialized(); checkConfigured();
for (Callback callback : callbacks) { for (Callback callback : callbacks) {
if (callback instanceof OAuthBearerValidatorCallback) { if (callback instanceof OAuthBearerValidatorCallback) {
@ -179,12 +180,12 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback
} }
private void handleValidatorCallback(OAuthBearerValidatorCallback callback) { private void handleValidatorCallback(OAuthBearerValidatorCallback callback) {
checkInitialized(); checkConfigured();
OAuthBearerToken token; OAuthBearerToken token;
try { try {
token = accessTokenValidator.validate(callback.tokenValue()); token = jwtValidator.validate(callback.tokenValue());
callback.token(token); callback.token(token);
} catch (ValidateException e) { } catch (ValidateException e) {
log.warn(e.getMessage(), e); log.warn(e.getMessage(), e);
@ -193,14 +194,14 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback
} }
private void handleExtensionsValidatorCallback(OAuthBearerExtensionsValidatorCallback extensionsValidatorCallback) { private void handleExtensionsValidatorCallback(OAuthBearerExtensionsValidatorCallback extensionsValidatorCallback) {
checkInitialized(); checkConfigured();
extensionsValidatorCallback.inputExtensions().map().forEach((extensionName, v) -> extensionsValidatorCallback.valid(extensionName)); extensionsValidatorCallback.inputExtensions().map().forEach((extensionName, v) -> extensionsValidatorCallback.valid(extensionName));
} }
private void checkInitialized() { private void checkConfigured() {
if (!isInitialized) if (verificationKeyResolver == null || jwtValidator == null)
throw new IllegalStateException(String.format("To use %s, first call the configure or init method", getClass().getSimpleName())); throw new IllegalStateException(String.format("To use %s, first call the configure method", getClass().getSimpleName()));
} }
/** /**

View File

@ -1,73 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.jose4j.keys.resolvers.VerificationKeyResolver;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_ISSUER;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SCOPE_CLAIM_NAME;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SUB_CLAIM_NAME;
public class AccessTokenValidatorFactory {
public static AccessTokenValidator create(Map<String, ?> configs) {
return create(configs, (String) null);
}
public static AccessTokenValidator create(Map<String, ?> configs, String saslMechanism) {
ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism);
String scopeClaimName = cu.get(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME);
String subClaimName = cu.get(SASL_OAUTHBEARER_SUB_CLAIM_NAME);
return new LoginAccessTokenValidator(scopeClaimName, subClaimName);
}
public static AccessTokenValidator create(Map<String, ?> configs,
VerificationKeyResolver verificationKeyResolver) {
return create(configs, null, verificationKeyResolver);
}
public static AccessTokenValidator create(Map<String, ?> configs,
String saslMechanism,
VerificationKeyResolver verificationKeyResolver) {
ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism);
Set<String> expectedAudiences = null;
List<String> l = cu.get(SASL_OAUTHBEARER_EXPECTED_AUDIENCE);
if (l != null)
expectedAudiences = Set.copyOf(l);
Integer clockSkew = cu.validateInteger(SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS, false);
String expectedIssuer = cu.validateString(SASL_OAUTHBEARER_EXPECTED_ISSUER, false);
String scopeClaimName = cu.validateString(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME);
String subClaimName = cu.validateString(SASL_OAUTHBEARER_SUB_CLAIM_NAME);
return new ValidatorAccessTokenValidator(clockSkew,
expectedAudiences,
expectedIssuer,
verificationKeyResolver,
scopeClaimName,
subClaimName);
}
}

View File

@ -38,7 +38,7 @@ import java.util.Set;
import static org.jose4j.jwa.AlgorithmConstraints.DISALLOW_NONE; import static org.jose4j.jwa.AlgorithmConstraints.DISALLOW_NONE;
/** /**
* ValidatorAccessTokenValidator is an implementation of {@link AccessTokenValidator} that is used * {@code BrokerJwtValidator} is an implementation of {@link JwtValidator} that is used
* by the broker to perform more extensive validation of the JWT access token that is received * by the broker to perform more extensive validation of the JWT access token that is received
* from the client, but ultimately from posting the client credentials to the OAuth/OIDC provider's * from the client, but ultimately from posting the client credentials to the OAuth/OIDC provider's
* token endpoint. * token endpoint.
@ -62,9 +62,9 @@ import static org.jose4j.jwa.AlgorithmConstraints.DISALLOW_NONE;
* </ol> * </ol>
*/ */
public class ValidatorAccessTokenValidator implements AccessTokenValidator { public class BrokerJwtValidator implements JwtValidator {
private static final Logger log = LoggerFactory.getLogger(ValidatorAccessTokenValidator.class); private static final Logger log = LoggerFactory.getLogger(BrokerJwtValidator.class);
private final JwtConsumer jwtConsumer; private final JwtConsumer jwtConsumer;
@ -73,7 +73,7 @@ public class ValidatorAccessTokenValidator implements AccessTokenValidator {
private final String subClaimName; private final String subClaimName;
/** /**
* Creates a new ValidatorAccessTokenValidator that will be used by the broker for more * Creates a new {@code BrokerJwtValidator} that will be used by the broker for more
* thorough validation of the JWT. * thorough validation of the JWT.
* *
* @param clockSkew The optional value (in seconds) to allow for differences * @param clockSkew The optional value (in seconds) to allow for differences
@ -112,12 +112,12 @@ public class ValidatorAccessTokenValidator implements AccessTokenValidator {
* @see VerificationKeyResolver * @see VerificationKeyResolver
*/ */
public ValidatorAccessTokenValidator(Integer clockSkew, public BrokerJwtValidator(Integer clockSkew,
Set<String> expectedAudiences, Set<String> expectedAudiences,
String expectedIssuer, String expectedIssuer,
VerificationKeyResolver verificationKeyResolver, VerificationKeyResolver verificationKeyResolver,
String scopeClaimName, String scopeClaimName,
String subClaimName) { String subClaimName) {
final JwtConsumerBuilder jwtConsumerBuilder = new JwtConsumerBuilder(); final JwtConsumerBuilder jwtConsumerBuilder = new JwtConsumerBuilder();
if (clockSkew != null) if (clockSkew != null)

View File

@ -33,7 +33,7 @@ import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARE
import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME; import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME;
/** /**
* LoginAccessTokenValidator is an implementation of {@link AccessTokenValidator} that is used * {@code ClientJwtValidator} is an implementation of {@link JwtValidator} that is used
* by the client to perform some rudimentary validation of the JWT access token that is received * by the client to perform some rudimentary validation of the JWT access token that is received
* as part of the response from posting the client credentials to the OAuth/OIDC provider's * as part of the response from posting the client credentials to the OAuth/OIDC provider's
* token endpoint. * token endpoint.
@ -46,13 +46,13 @@ import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARE
* <a href="https://tools.ietf.org/html/rfc6750#section-2.1">RFC 6750 Section 2.1</a> * <a href="https://tools.ietf.org/html/rfc6750#section-2.1">RFC 6750 Section 2.1</a>
* </li> * </li>
* <li>Basic conversion of the token into an in-memory map</li> * <li>Basic conversion of the token into an in-memory map</li>
* <li>Presence of scope, <code>exp</code>, subject, and <code>iat</code> claims</li> * <li>Presence of <code>scope</code>, <code>exp</code>, <code>subject</code>, and <code>iat</code> claims</li>
* </ol> * </ol>
*/ */
public class LoginAccessTokenValidator implements AccessTokenValidator { public class ClientJwtValidator implements JwtValidator {
private static final Logger log = LoggerFactory.getLogger(LoginAccessTokenValidator.class); private static final Logger log = LoggerFactory.getLogger(ClientJwtValidator.class);
public static final String EXPIRATION_CLAIM_NAME = "exp"; public static final String EXPIRATION_CLAIM_NAME = "exp";
@ -63,14 +63,14 @@ public class LoginAccessTokenValidator implements AccessTokenValidator {
private final String subClaimName; private final String subClaimName;
/** /**
* Creates a new LoginAccessTokenValidator that will be used by the client for lightweight * Creates a new {@code ClientJwtValidator} that will be used by the client for lightweight
* validation of the JWT. * validation of the JWT.
* *
* @param scopeClaimName Name of the scope claim to use; must be non-<code>null</code> * @param scopeClaimName Name of the scope claim to use; must be non-<code>null</code>
* @param subClaimName Name of the subject claim to use; must be non-<code>null</code> * @param subClaimName Name of the subject claim to use; must be non-<code>null</code>
*/ */
public LoginAccessTokenValidator(String scopeClaimName, String subClaimName) { public ClientJwtValidator(String scopeClaimName, String subClaimName) {
this.scopeClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SCOPE_CLAIM_NAME, scopeClaimName); this.scopeClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SCOPE_CLAIM_NAME, scopeClaimName);
this.subClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME, subClaimName); this.subClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME, subClaimName);
} }

View File

@ -18,7 +18,9 @@
package org.apache.kafka.common.security.oauthbearer.internals.secured; package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.config.SaslConfigs; import org.apache.kafka.common.config.SaslConfigs;
import org.apache.kafka.common.utils.Utils;
import java.io.IOException;
import java.net.URL; import java.net.URL;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
@ -36,32 +38,33 @@ import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallb
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG; import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG;
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.SCOPE_CONFIG; import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.SCOPE_CONFIG;
public class AccessTokenRetrieverFactory { /**
* {@code DefaultJwtRetriever} instantiates and delegates {@link JwtRetriever} API calls to an embedded implementation
* based on configuration. If {@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL} is configured with a
* {@code file}-based URL, a {@link FileJwtRetriever} is created and the JWT is expected be contained in the file
* specified. Otherwise, it's assumed to be an HTTP/HTTPS-based URL, so an {@link HttpJwtRetriever} is created.
*/
public class DefaultJwtRetriever implements JwtRetriever {
/** private final Map<String, ?> configs;
* Create an {@link AccessTokenRetriever} from the given SASL and JAAS configuration. private final String saslMechanism;
* private final Map<String, Object> jaasConfig;
* <b>Note</b>: the returned <code>AccessTokenRetriever</code> is <em>not</em> initialized
* here and must be done by the caller prior to use.
*
* @param configs SASL configuration
* @param jaasConfig JAAS configuration
*
* @return Non-<code>null</code> {@link AccessTokenRetriever}
*/
public static AccessTokenRetriever create(Map<String, ?> configs, Map<String, Object> jaasConfig) { private JwtRetriever delegate;
return create(configs, null, jaasConfig);
public DefaultJwtRetriever(Map<String, ?> configs, String saslMechanism, Map<String, Object> jaasConfig) {
this.configs = configs;
this.saslMechanism = saslMechanism;
this.jaasConfig = jaasConfig;
} }
public static AccessTokenRetriever create(Map<String, ?> configs, @Override
String saslMechanism, public void init() throws IOException {
Map<String, Object> jaasConfig) {
ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism); ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism);
URL tokenEndpointUrl = cu.validateUrl(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL); URL tokenEndpointUrl = cu.validateUrl(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL);
if (tokenEndpointUrl.getProtocol().toLowerCase(Locale.ROOT).equals("file")) { if (tokenEndpointUrl.getProtocol().toLowerCase(Locale.ROOT).equals("file")) {
return new FileTokenRetriever(cu.validateFile(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL)); delegate = new FileJwtRetriever(cu.validateFile(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL));
} else { } else {
JaasOptionsUtils jou = new JaasOptionsUtils(jaasConfig); JaasOptionsUtils jou = new JaasOptionsUtils(jaasConfig);
String clientId = jou.validateString(CLIENT_ID_CONFIG); String clientId = jou.validateString(CLIENT_ID_CONFIG);
@ -75,7 +78,7 @@ public class AccessTokenRetrieverFactory {
boolean urlencodeHeader = validateUrlencodeHeader(cu); boolean urlencodeHeader = validateUrlencodeHeader(cu);
return new HttpAccessTokenRetriever(clientId, delegate = new HttpJwtRetriever(clientId,
clientSecret, clientSecret,
scope, scope,
sslSocketFactory, sslSocketFactory,
@ -86,6 +89,21 @@ public class AccessTokenRetrieverFactory {
cu.validateInteger(SASL_LOGIN_READ_TIMEOUT_MS, false), cu.validateInteger(SASL_LOGIN_READ_TIMEOUT_MS, false),
urlencodeHeader); urlencodeHeader);
} }
delegate.init();
}
@Override
public String retrieve() throws IOException {
if (delegate == null)
throw new IllegalStateException("JWT retriever delegate is null; please call init() first");
return delegate.retrieve();
}
@Override
public void close() throws IOException {
Utils.closeQuietly(delegate, "JWT retriever delegate");
} }
/** /**
@ -96,10 +114,10 @@ public class AccessTokenRetrieverFactory {
* <p/> * <p/>
* *
* This utility method ensures that we have a non-{@code null} value to use in the * This utility method ensures that we have a non-{@code null} value to use in the
* {@link HttpAccessTokenRetriever} constructor. * {@link HttpJwtRetriever} constructor.
*/ */
static boolean validateUrlencodeHeader(ConfigurationUtils configurationUtils) { static boolean validateUrlencodeHeader(ConfigurationUtils configurationUtils) {
Boolean urlencodeHeader = configurationUtils.validateBoolean(SASL_OAUTHBEARER_HEADER_URLENCODE, false); Boolean urlencodeHeader = configurationUtils.get(SASL_OAUTHBEARER_HEADER_URLENCODE);
if (urlencodeHeader != null) if (urlencodeHeader != null)
return urlencodeHeader; return urlencodeHeader;
@ -107,4 +125,7 @@ public class AccessTokenRetrieverFactory {
return DEFAULT_SASL_OAUTHBEARER_HEADER_URLENCODE; return DEFAULT_SASL_OAUTHBEARER_HEADER_URLENCODE;
} }
JwtRetriever delegate() {
return delegate;
}
} }

View File

@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.utils.Utils;
import org.jose4j.keys.resolvers.VerificationKeyResolver;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_ISSUER;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SCOPE_CLAIM_NAME;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SUB_CLAIM_NAME;
/**
* This {@link JwtValidator} uses the delegation approach, instantiating and delegating calls to a
* more concrete implementation. The underlying implementation is determined by the presence/absence
* of the {@link VerificationKeyResolver}: if it's present, a {@link BrokerJwtValidator} is
* created, otherwise a {@link ClientJwtValidator} is created.
*/
public class DefaultJwtValidator implements JwtValidator {
private final Map<String, ?> configs;
private final String saslMechanism;
private final Optional<VerificationKeyResolver> verificationKeyResolver;
private JwtValidator delegate;
public DefaultJwtValidator(Map<String, ?> configs, String saslMechanism) {
this.configs = configs;
this.saslMechanism = saslMechanism;
this.verificationKeyResolver = Optional.empty();
}
public DefaultJwtValidator(Map<String, ?> configs,
String saslMechanism,
VerificationKeyResolver verificationKeyResolver) {
this.configs = configs;
this.saslMechanism = saslMechanism;
this.verificationKeyResolver = Optional.of(verificationKeyResolver);
}
@Override
public void init() throws IOException {
ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism);
if (verificationKeyResolver.isPresent()) {
List<String> expectedAudiencesList = cu.get(SASL_OAUTHBEARER_EXPECTED_AUDIENCE);
Set<String> expectedAudiences = expectedAudiencesList != null ? Set.copyOf(expectedAudiencesList) : null;
Integer clockSkew = cu.validateInteger(SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS, false);
String expectedIssuer = cu.validateString(SASL_OAUTHBEARER_EXPECTED_ISSUER, false);
String scopeClaimName = cu.validateString(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME);
String subClaimName = cu.validateString(SASL_OAUTHBEARER_SUB_CLAIM_NAME);
delegate = new BrokerJwtValidator(clockSkew,
expectedAudiences,
expectedIssuer,
verificationKeyResolver.get(),
scopeClaimName,
subClaimName);
} else {
String scopeClaimName = cu.get(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME);
String subClaimName = cu.get(SASL_OAUTHBEARER_SUB_CLAIM_NAME);
delegate = new ClientJwtValidator(scopeClaimName, subClaimName);
}
delegate.init();
}
@Override
public OAuthBearerToken validate(String accessToken) throws ValidateException {
if (delegate == null)
throw new IllegalStateException("JWT validator delegate is null; please call init() first");
return delegate.validate(accessToken);
}
@Override
public void close() throws IOException {
Utils.closeQuietly(delegate, "JWT validator delegate");
}
JwtValidator delegate() {
return delegate;
}
}

View File

@ -23,19 +23,19 @@ import java.io.IOException;
import java.nio.file.Path; import java.nio.file.Path;
/** /**
* <code>FileTokenRetriever</code> is an {@link AccessTokenRetriever} that will load the contents, * <code>FileJwtRetriever</code> is an {@link JwtRetriever} that will load the contents
* interpreting them as a JWT access key in the serialized form. * of a file, interpreting them as a JWT access key in the serialized form.
* *
* @see AccessTokenRetriever * @see JwtRetriever
*/ */
public class FileTokenRetriever implements AccessTokenRetriever { public class FileJwtRetriever implements JwtRetriever {
private final Path accessTokenFile; private final Path accessTokenFile;
private String accessToken; private String accessToken;
public FileTokenRetriever(Path accessTokenFile) { public FileJwtRetriever(Path accessTokenFile) {
this.accessTokenFile = accessTokenFile; this.accessTokenFile = accessTokenFile;
} }

View File

@ -49,22 +49,14 @@ import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.SSLSocketFactory;
/** /**
* <code>HttpAccessTokenRetriever</code> is an {@link AccessTokenRetriever} that will * <code>HttpJwtRetriever</code> is a {@link JwtRetriever} that will communicate with an OAuth/OIDC
* communicate with an OAuth/OIDC provider directly via HTTP to post client credentials * provider directly via HTTP to post client credentials
* ({@link OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG}/{@link OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG}) * ({@link OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG}/{@link OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG})
* to a publicized token endpoint URL * to a publicized token endpoint URL ({@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL}).
* ({@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL}).
*
* @see AccessTokenRetriever
* @see OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG
* @see OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG
* @see OAuthBearerLoginCallbackHandler#SCOPE_CONFIG
* @see SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL
*/ */
public class HttpJwtRetriever implements JwtRetriever {
public class HttpAccessTokenRetriever implements AccessTokenRetriever { private static final Logger log = LoggerFactory.getLogger(HttpJwtRetriever.class);
private static final Logger log = LoggerFactory.getLogger(HttpAccessTokenRetriever.class);
private static final Set<Integer> UNRETRYABLE_HTTP_CODES; private static final Set<Integer> UNRETRYABLE_HTTP_CODES;
@ -117,16 +109,16 @@ public class HttpAccessTokenRetriever implements AccessTokenRetriever {
private final boolean urlencodeHeader; private final boolean urlencodeHeader;
public HttpAccessTokenRetriever(String clientId, public HttpJwtRetriever(String clientId,
String clientSecret, String clientSecret,
String scope, String scope,
SSLSocketFactory sslSocketFactory, SSLSocketFactory sslSocketFactory,
String tokenEndpointUrl, String tokenEndpointUrl,
long loginRetryBackoffMs, long loginRetryBackoffMs,
long loginRetryBackoffMaxMs, long loginRetryBackoffMaxMs,
Integer loginConnectTimeoutMs, Integer loginConnectTimeoutMs,
Integer loginReadTimeoutMs, Integer loginReadTimeoutMs,
boolean urlencodeHeader) { boolean urlencodeHeader) {
this.clientId = Objects.requireNonNull(clientId); this.clientId = Objects.requireNonNull(clientId);
this.clientSecret = Objects.requireNonNull(clientSecret); this.clientSecret = Objects.requireNonNull(clientSecret);
this.scope = scope; this.scope = scope;

View File

@ -22,8 +22,8 @@ import java.io.IOException;
public interface Initable { public interface Initable {
/** /**
* Lifecycle method to perform any one-time initialization of the retriever. This must * Lifecycle method to perform any one-time initialization of a given resource. This must
* be performed by the caller to ensure the correct state before methods are invoked. * be invoked by the caller to ensure the correct state before methods are invoked.
* *
* @throws IOException Thrown on errors related to IO during initialization * @throws IOException Thrown on errors related to IO during initialization
*/ */
@ -31,5 +31,4 @@ public interface Initable {
default void init() throws IOException { default void init() throws IOException {
// This method left intentionally blank. // This method left intentionally blank.
} }
} }

View File

@ -21,20 +21,20 @@ import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
/** /**
* An <code>AccessTokenRetriever</code> is the internal API by which the login module will * A <code>JwtRetriever</code> is the internal API by which the login module will
* retrieve an access token for use in authorization by the broker. The implementation may * retrieve an access token for use in authorization by the broker. The implementation may
* involve authentication to a remote system, or it can be as simple as loading the contents * involve authentication to a remote system, or it can be as simple as loading the contents
* of a file or configuration setting. * of a file or configuration setting.
* *
* <i>Retrieval</i> is a separate concern from <i>validation</i>, so it isn't necessary for * <i>Retrieval</i> is a separate concern from <i>validation</i>, so it isn't necessary for
* the <code>AccessTokenRetriever</code> implementation to validate the integrity of the JWT * the <code>JwtRetriever</code> implementation to validate the integrity of the JWT
* access token. * access token.
* *
* @see HttpAccessTokenRetriever * @see HttpJwtRetriever
* @see FileTokenRetriever * @see FileJwtRetriever
*/ */
public interface AccessTokenRetriever extends Initable, Closeable { public interface JwtRetriever extends Initable, Closeable {
/** /**
* Retrieves a JWT access token in its serialized three-part form. The implementation * Retrieves a JWT access token in its serialized three-part form. The implementation

View File

@ -19,8 +19,11 @@ package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import java.io.Closeable;
import java.io.IOException;
/** /**
* An instance of <code>AccessTokenValidator</code> acts as a function object that, given an access * An instance of <code>JwtValidator</code> acts as a function object that, given an access
* token in base-64 encoded JWT format, can parse the data, perform validation, and construct an * token in base-64 encoded JWT format, can parse the data, perform validation, and construct an
* {@link OAuthBearerToken} for use by the caller. * {@link OAuthBearerToken} for use by the caller.
* *
@ -40,13 +43,12 @@ import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
* <li><a href="https://datatracker.ietf.org/doc/html/draft-ietf-oauth-access-token-jwt">RFC 6750, Section 2.1</a></li> * <li><a href="https://datatracker.ietf.org/doc/html/draft-ietf-oauth-access-token-jwt">RFC 6750, Section 2.1</a></li>
* </ul> * </ul>
* *
* @see LoginAccessTokenValidator A basic AccessTokenValidator used by client-side login * @see ClientJwtValidator A basic JwtValidator used by client-side login authentication
* authentication * @see BrokerJwtValidator A more robust JwtValidator that is used on the broker to validate the token's
* @see ValidatorAccessTokenValidator A more robust AccessTokenValidator that is used on the broker * contents and verify the signature
* to validate the token's contents and verify the signature
*/ */
public interface AccessTokenValidator { public interface JwtValidator extends Initable, Closeable {
/** /**
* Accepts an OAuth JWT access token in base-64 encoded format, validates, and returns an * Accepts an OAuth JWT access token in base-64 encoded format, validates, and returns an
@ -61,4 +63,10 @@ public interface AccessTokenValidator {
OAuthBearerToken validate(String accessToken) throws ValidateException; OAuthBearerToken validate(String accessToken) throws ValidateException;
/**
* Closes any resources that were initialized by {@link #init()}.
*/
default void close() throws IOException {
// Do nothing...
}
} }

View File

@ -49,12 +49,12 @@ import java.util.concurrent.locks.ReentrantReadWriteLock;
* This instance is created and provided to the * This instance is created and provided to the
* {@link org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver} that is used when using * {@link org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver} that is used when using
* an HTTP-/HTTPS-based {@link org.jose4j.keys.resolvers.VerificationKeyResolver}, which is then * an HTTP-/HTTPS-based {@link org.jose4j.keys.resolvers.VerificationKeyResolver}, which is then
* provided to the {@link ValidatorAccessTokenValidator} to use in validating the signature of * provided to the {@link BrokerJwtValidator} to use in validating the signature of
* a JWT. * a JWT.
* *
* @see org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver * @see org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver
* @see org.jose4j.keys.resolvers.VerificationKeyResolver * @see org.jose4j.keys.resolvers.VerificationKeyResolver
* @see ValidatorAccessTokenValidator * @see BrokerJwtValidator
*/ */
public final class RefreshingHttpsJwks implements Initable, Closeable { public final class RefreshingHttpsJwks implements Initable, Closeable {

View File

@ -27,7 +27,7 @@ import javax.security.auth.callback.Callback;
* processing of a {@link javax.security.auth.callback.CallbackHandler#handle(Callback[])}. * processing of a {@link javax.security.auth.callback.CallbackHandler#handle(Callback[])}.
* This error, however, is not thrown from that method directly. * This error, however, is not thrown from that method directly.
* *
* @see AccessTokenValidator#validate(String) * @see JwtValidator#validate(String)
*/ */
public class ValidateException extends KafkaException { public class ValidateException extends KafkaException {

View File

@ -37,7 +37,7 @@ import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_JWKS_E
public class VerificationKeyResolverFactory { public class VerificationKeyResolverFactory {
/** /**
* Create an {@link AccessTokenRetriever} from the given * Create a {@link JwtRetriever} from the given
* {@link org.apache.kafka.common.config.SaslConfigs}. * {@link org.apache.kafka.common.config.SaslConfigs}.
* *
* <b>Note</b>: the returned <code>CloseableVerificationKeyResolver</code> is not * <b>Note</b>: the returned <code>CloseableVerificationKeyResolver</code> is not

View File

@ -21,13 +21,12 @@ import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.security.auth.SaslExtensionsCallback; import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse; import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder; import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever; import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator; import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory; import org.apache.kafka.common.security.oauthbearer.internals.secured.FileJwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.FileTokenRetriever; import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.HttpAccessTokenRetriever; import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest; import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest;
import org.apache.kafka.common.utils.Utils;
import org.jose4j.jws.AlgorithmIdentifiers; import org.jose4j.jws.AlgorithmIdentifiers;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
@ -35,9 +34,7 @@ import org.junit.jupiter.api.Test;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.Base64;
import java.util.Calendar; import java.util.Calendar;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.TimeZone; import java.util.TimeZone;
@ -50,7 +47,6 @@ import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.ALL
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_ID_CONFIG; import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_ID_CONFIG;
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG; import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@ -58,6 +54,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest { public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@AfterEach @AfterEach
public void tearDown() throws Exception { public void tearDown() throws Exception {
System.clearProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG); System.clearProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG);
@ -70,9 +67,10 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
.jwk(createRsaJwk()) .jwk(createRsaJwk())
.alg(AlgorithmIdentifiers.RSA_USING_SHA256); .alg(AlgorithmIdentifiers.RSA_USING_SHA256);
String accessToken = builder.build(); String accessToken = builder.build();
AccessTokenRetriever accessTokenRetriever = () -> accessToken; JwtRetriever jwtRetriever = () -> accessToken;
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = createHandler(accessTokenRetriever, configs); OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(Map.of(), jwtRetriever, jwtValidator);
try { try {
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
@ -91,7 +89,6 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@Test @Test
public void testHandleSaslExtensionsCallback() throws Exception { public void testHandleSaslExtensionsCallback() throws Exception {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com"); Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com"); System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com");
Map<String, Object> jaasConfig = new HashMap<>(); Map<String, Object> jaasConfig = new HashMap<>();
@ -100,7 +97,11 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
jaasConfig.put("extension_foo", "1"); jaasConfig.put("extension_foo", "1");
jaasConfig.put("extension_bar", 2); jaasConfig.put("extension_bar", 2);
jaasConfig.put("EXTENSION_baz", "3"); jaasConfig.put("EXTENSION_baz", "3");
configureHandler(handler, configs, jaasConfig);
JwtRetriever jwtRetriever = createJwtRetriever(configs, jaasConfig);
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(jaasConfig, jwtRetriever, jwtValidator);
try { try {
SaslExtensionsCallback callback = new SaslExtensionsCallback(); SaslExtensionsCallback callback = new SaslExtensionsCallback();
@ -121,14 +122,17 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
public void testHandleSaslExtensionsCallbackWithInvalidExtension() { public void testHandleSaslExtensionsCallbackWithInvalidExtension() {
String illegalKey = "extension_" + OAuthBearerClientInitialResponse.AUTH_KEY; String illegalKey = "extension_" + OAuthBearerClientInitialResponse.AUTH_KEY;
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com"); Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com"); System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com");
Map<String, Object> jaasConfig = new HashMap<>(); Map<String, Object> jaasConfig = new HashMap<>();
jaasConfig.put(CLIENT_ID_CONFIG, "an ID"); jaasConfig.put(CLIENT_ID_CONFIG, "an ID");
jaasConfig.put(CLIENT_SECRET_CONFIG, "a secret"); jaasConfig.put(CLIENT_SECRET_CONFIG, "a secret");
jaasConfig.put(illegalKey, "this key isn't allowed per OAuthBearerClientInitialResponse.validateExtensions"); jaasConfig.put(illegalKey, "this key isn't allowed per OAuthBearerClientInitialResponse.validateExtensions");
configureHandler(handler, configs, jaasConfig);
JwtRetriever jwtRetriever = createJwtRetriever(configs, jaasConfig);
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(jaasConfig, jwtRetriever, jwtValidator);
try { try {
SaslExtensionsCallback callback = new SaslExtensionsCallback(); SaslExtensionsCallback callback = new SaslExtensionsCallback();
@ -143,10 +147,10 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@Test @Test
public void testInvalidCallbackGeneratesUnsupportedCallbackException() { public void testInvalidCallbackGeneratesUnsupportedCallbackException() {
Map<String, ?> configs = getSaslConfigs(); Map<String, ?> configs = getSaslConfigs();
JwtRetriever jwtRetriever = () -> "test";
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
AccessTokenRetriever accessTokenRetriever = () -> "foo"; handler.init(Map.of(), jwtRetriever, jwtValidator);
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs);
handler.init(accessTokenRetriever, accessTokenValidator);
try { try {
Callback unsupportedCallback = new Callback() { }; Callback unsupportedCallback = new Callback() { };
@ -166,11 +170,13 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@Test @Test
public void testMissingAccessToken() { public void testMissingAccessToken() {
AccessTokenRetriever accessTokenRetriever = () -> { Map<String, ?> configs = getSaslConfigs();
JwtRetriever jwtRetriever = () -> {
throw new IOException("The token endpoint response access_token value must be non-null"); throw new IOException("The token endpoint response access_token value must be non-null");
}; };
Map<String, ?> configs = getSaslConfigs(); JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = createHandler(accessTokenRetriever, configs); OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(Map.of(), jwtRetriever, jwtValidator);
try { try {
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
@ -196,7 +202,11 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", withNewline); File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", withNewline);
Map<String, ?> configs = getSaslConfigs(); Map<String, ?> configs = getSaslConfigs();
OAuthBearerLoginCallbackHandler handler = createHandler(new FileTokenRetriever(accessTokenFile.toPath()), configs); JwtRetriever jwtRetriever = new FileJwtRetriever(accessTokenFile.toPath());
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(Map.of(), jwtRetriever, jwtValidator);
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
try { try {
handler.handle(new Callback[]{callback}); handler.handle(new Callback[]{callback});
@ -211,39 +221,15 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@Test @Test
public void testNotConfigured() { public void testNotConfigured() {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
assertThrowsWithMessage(IllegalStateException.class, () -> handler.handle(new Callback[] {}), "first call the configure or init method"); assertThrowsWithMessage(IllegalStateException.class, () -> handler.handle(new Callback[] {}), "first call the configure method");
}
@Test
public void testConfigureWithAccessTokenFile() throws Exception {
String expected = "{}";
File tmpDir = createTempDir("access-token");
File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", expected);
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString());
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
Map<String, Object> jaasConfigs = Collections.emptyMap();
configureHandler(handler, configs, jaasConfigs);
assertInstanceOf(FileTokenRetriever.class, handler.getAccessTokenRetriever());
}
@Test
public void testConfigureWithAccessClientCredentials() {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com");
Map<String, Object> jaasConfigs = new HashMap<>();
jaasConfigs.put(CLIENT_ID_CONFIG, "an ID");
jaasConfigs.put(CLIENT_SECRET_CONFIG, "a secret");
configureHandler(handler, configs, jaasConfigs);
assertInstanceOf(HttpAccessTokenRetriever.class, handler.getAccessTokenRetriever());
} }
private void testInvalidAccessToken(String accessToken, String expectedMessageSubstring) throws Exception { private void testInvalidAccessToken(String accessToken, String expectedMessageSubstring) throws Exception {
Map<String, ?> configs = getSaslConfigs(); Map<String, ?> configs = getSaslConfigs();
OAuthBearerLoginCallbackHandler handler = createHandler(() -> accessToken, configs); JwtRetriever jwtRetriever = () -> accessToken;
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(Map.of(), jwtRetriever, jwtValidator);
try { try {
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
@ -260,19 +246,15 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
} }
} }
private String createAccessKey(String header, String payload, String signature) { private static DefaultJwtRetriever createJwtRetriever(Map<String, ?> configs) {
Base64.Encoder enc = Base64.getEncoder(); return createJwtRetriever(configs, Map.of());
header = enc.encodeToString(Utils.utf8(header));
payload = enc.encodeToString(Utils.utf8(payload));
signature = enc.encodeToString(Utils.utf8(signature));
return String.format("%s.%s.%s", header, payload, signature);
} }
private OAuthBearerLoginCallbackHandler createHandler(AccessTokenRetriever accessTokenRetriever, Map<String, ?> configs) { private static DefaultJwtRetriever createJwtRetriever(Map<String, ?> configs, Map<String, Object> jaasConfigs) {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); return new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfigs);
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs);
handler.init(accessTokenRetriever, accessTokenValidator);
return handler;
} }
private static DefaultJwtValidator createJwtValidator(Map<String, ?> configs) {
return new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM);
}
} }

View File

@ -17,27 +17,30 @@
package org.apache.kafka.common.security.oauthbearer; package org.apache.kafka.common.security.oauthbearer;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder; import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver; import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest; import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest;
import org.apache.kafka.common.utils.Utils; import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException;
import org.jose4j.jws.AlgorithmIdentifiers; import org.jose4j.jws.AlgorithmIdentifiers;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import javax.security.auth.callback.Callback; import javax.security.auth.callback.Callback;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE; import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest { public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest {
@ -53,7 +56,10 @@ public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest {
String accessToken = builder.build(); String accessToken = builder.build();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_EXPECTED_AUDIENCE, allAudiences); Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_EXPECTED_AUDIENCE, allAudiences);
OAuthBearerValidatorCallbackHandler handler = createHandler(configs, builder); CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
JwtValidator jwtValidator = createJwtValidator(configs, verificationKeyResolver);
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
handler.init(verificationKeyResolver, jwtValidator);
try { try {
OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken); OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken);
@ -81,9 +87,68 @@ public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest {
assertInvalidAccessTokenFails(createAccessKey("{}", "{}", "{}"), substring); assertInvalidAccessTokenFails(createAccessKey("{}", "{}", "{}"), substring);
} }
@Test
public void testHandlerInitThrowsException() throws IOException {
IOException initError = new IOException("init() error");
AccessTokenBuilder builder = new AccessTokenBuilder()
.alg(AlgorithmIdentifiers.RSA_USING_SHA256);
CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
JwtValidator jwtValidator = new JwtValidator() {
@Override
public void init() throws IOException {
throw initError;
}
@Override
public OAuthBearerToken validate(String accessToken) throws ValidateException {
return null;
}
};
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
// An error initializing the JwtValidator should cause OAuthBearerValidatorCallbackHandler.init() to fail.
KafkaException root = assertThrows(
KafkaException.class,
() -> handler.init(verificationKeyResolver, jwtValidator)
);
assertNotNull(root.getCause());
assertEquals(initError, root.getCause());
}
@Test
public void testHandlerCloseDoesNotThrowException() throws IOException {
AccessTokenBuilder builder = new AccessTokenBuilder()
.alg(AlgorithmIdentifiers.RSA_USING_SHA256);
CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
JwtValidator jwtValidator = new JwtValidator() {
@Override
public void close() throws IOException {
throw new IOException("close() error");
}
@Override
public OAuthBearerToken validate(String accessToken) throws ValidateException {
return null;
}
};
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
handler.init(verificationKeyResolver, jwtValidator);
// An error closings the JwtValidator should *not* cause OAuthBearerValidatorCallbackHandler.close() to fail.
assertDoesNotThrow(handler::close);
}
private void assertInvalidAccessTokenFails(String accessToken, String expectedMessageSubstring) throws Exception { private void assertInvalidAccessTokenFails(String accessToken, String expectedMessageSubstring) throws Exception {
AccessTokenBuilder builder = new AccessTokenBuilder()
.alg(AlgorithmIdentifiers.RSA_USING_SHA256);
Map<String, ?> configs = getSaslConfigs(); Map<String, ?> configs = getSaslConfigs();
OAuthBearerValidatorCallbackHandler handler = createHandler(configs, new AccessTokenBuilder()); CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
JwtValidator jwtValidator = createJwtValidator(configs, verificationKeyResolver);
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
handler.init(verificationKeyResolver, jwtValidator);
try { try {
OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken); OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken);
@ -98,22 +163,11 @@ public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest {
} }
} }
private OAuthBearerValidatorCallbackHandler createHandler(Map<String, ?> options, private JwtValidator createJwtValidator(Map<String, ?> configs, CloseableVerificationKeyResolver verificationKeyResolver) {
AccessTokenBuilder builder) { return new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, verificationKeyResolver);
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
CloseableVerificationKeyResolver verificationKeyResolver = (jws, nestingContext) ->
builder.jwk().getPublicKey();
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(options, verificationKeyResolver);
handler.init(verificationKeyResolver, accessTokenValidator);
return handler;
} }
private String createAccessKey(String header, String payload, String signature) { private CloseableVerificationKeyResolver createVerificationKeyResolver(AccessTokenBuilder builder) {
Base64.Encoder enc = Base64.getEncoder(); return (jws, nestingContext) -> builder.jwk().getPublicKey();
header = enc.encodeToString(Utils.utf8(header));
payload = enc.encodeToString(Utils.utf8(payload));
signature = enc.encodeToString(Utils.utf8(signature));
return String.format("%s.%s.%s", header, payload, signature);
} }
} }

View File

@ -1,73 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.Map;
public class AccessTokenValidatorFactoryTest extends OAuthBearerTest {
@Test
public void testConfigureThrowsExceptionOnAccessTokenValidatorInit() {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
AccessTokenRetriever accessTokenRetriever = new AccessTokenRetriever() {
@Override
public void init() throws IOException {
throw new IOException("My init had an error!");
}
@Override
public String retrieve() {
return "dummy";
}
};
Map<String, ?> configs = getSaslConfigs();
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs);
assertThrowsWithMessage(
KafkaException.class, () -> handler.init(accessTokenRetriever, accessTokenValidator), "encountered an error when initializing");
}
@Test
public void testConfigureThrowsExceptionOnAccessTokenValidatorClose() {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
AccessTokenRetriever accessTokenRetriever = new AccessTokenRetriever() {
@Override
public void close() throws IOException {
throw new IOException("My close had an error!");
}
@Override
public String retrieve() {
return "dummy";
}
};
Map<String, ?> configs = getSaslConfigs();
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs);
handler.init(accessTokenRetriever, accessTokenValidator);
// Basically asserting this doesn't throw an exception :(
handler.close();
}
}

View File

@ -28,11 +28,11 @@ import java.util.Collections;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class ValidatorAccessTokenValidatorTest extends AccessTokenValidatorTest { public class BrokerJwtValidatorTest extends JwtValidatorTest {
@Override @Override
protected AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder builder) { protected JwtValidator createJwtValidator(AccessTokenBuilder builder) {
return new ValidatorAccessTokenValidator(30, return new BrokerJwtValidator(30,
Collections.emptySet(), Collections.emptySet(),
null, null,
(jws, nestingContext) -> builder.jwk().getKey(), (jws, nestingContext) -> builder.jwk().getKey(),
@ -72,7 +72,7 @@ public class ValidatorAccessTokenValidatorTest extends AccessTokenValidatorTest
.addCustomClaim(subClaimName, subject) .addCustomClaim(subClaimName, subject)
.subjectClaimName(subClaimName) .subjectClaimName(subClaimName)
.subject(null); .subject(null);
AccessTokenValidator validator = createAccessTokenValidator(tokenBuilder); JwtValidator validator = createJwtValidator(tokenBuilder);
// Validation should succeed (e.g. signature verification) even if sub claim is missing // Validation should succeed (e.g. signature verification) even if sub claim is missing
OAuthBearerToken token = validator.validate(tokenBuilder.build()); OAuthBearerToken token = validator.validate(tokenBuilder.build());
@ -82,7 +82,7 @@ public class ValidatorAccessTokenValidatorTest extends AccessTokenValidatorTest
private void testEncryptionAlgorithm(PublicJsonWebKey jwk, String alg) throws Exception { private void testEncryptionAlgorithm(PublicJsonWebKey jwk, String alg) throws Exception {
AccessTokenBuilder builder = new AccessTokenBuilder().jwk(jwk).alg(alg); AccessTokenBuilder builder = new AccessTokenBuilder().jwk(jwk).alg(alg);
AccessTokenValidator validator = createAccessTokenValidator(builder); JwtValidator validator = createJwtValidator(builder);
String accessToken = builder.build(); String accessToken = builder.build();
OAuthBearerToken token = validator.validate(accessToken); OAuthBearerToken token = validator.validate(accessToken);

View File

@ -17,11 +17,11 @@
package org.apache.kafka.common.security.oauthbearer.internals.secured; package org.apache.kafka.common.security.oauthbearer.internals.secured;
public class LoginAccessTokenValidatorTest extends AccessTokenValidatorTest { public class ClientJwtValidatorTest extends JwtValidatorTest {
@Override @Override
protected AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder builder) { protected JwtValidator createJwtValidator(AccessTokenBuilder builder) {
return new LoginAccessTokenValidator(builder.scopeClaimName(), builder.subjectClaimName()); return new ClientJwtValidator(builder.scopeClaimName(), builder.subjectClaimName());
} }
} }

View File

@ -18,6 +18,7 @@
package org.apache.kafka.common.security.oauthbearer.internals.secured; package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.config.ConfigException; import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -26,7 +27,9 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -34,9 +37,13 @@ import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARE
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_HEADER_URLENCODE; import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_HEADER_URLENCODE;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL; import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL;
import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG; import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG;
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_ID_CONFIG;
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest { public class DefaultJwtRetrieverTest extends OAuthBearerTest {
@AfterEach @AfterEach
public void tearDown() throws Exception { public void tearDown() throws Exception {
@ -44,7 +51,7 @@ public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest {
} }
@Test @Test
public void testConfigureRefreshingFileAccessTokenRetriever() throws Exception { public void testConfigureRefreshingFileJwtRetriever() throws Exception {
String expected = "{}"; String expected = "{}";
File tmpDir = createTempDir("access-token"); File tmpDir = createTempDir("access-token");
@ -54,31 +61,37 @@ public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest {
Map<String, ?> configs = Collections.singletonMap(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); Map<String, ?> configs = Collections.singletonMap(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
Map<String, Object> jaasConfig = Collections.emptyMap(); Map<String, Object> jaasConfig = Collections.emptyMap();
try (AccessTokenRetriever accessTokenRetriever = AccessTokenRetrieverFactory.create(configs, jaasConfig)) { try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfig)) {
accessTokenRetriever.init(); jwtRetriever.init();
assertEquals(expected, accessTokenRetriever.retrieve()); assertEquals(expected, jwtRetriever.retrieve());
} }
} }
@Test @Test
public void testConfigureRefreshingFileAccessTokenRetrieverWithInvalidDirectory() { public void testConfigureRefreshingFileJwtRetrieverWithInvalidDirectory() throws IOException {
// Should fail because the parent path doesn't exist. // Should fail because the parent path doesn't exist.
String file = new File("/tmp/this-directory-does-not-exist/foo.json").toURI().toString(); String file = new File("/tmp/this-directory-does-not-exist/foo.json").toURI().toString();
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, file); System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, file);
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, file); Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, file);
Map<String, Object> jaasConfig = Collections.emptyMap(); Map<String, Object> jaasConfig = Collections.emptyMap();
assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, jaasConfig), "that doesn't exist");
try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfig)) {
assertThrowsWithMessage(ConfigException.class, jwtRetriever::init, "that doesn't exist");
}
} }
@Test @Test
public void testConfigureRefreshingFileAccessTokenRetrieverWithInvalidFile() throws Exception { public void testConfigureRefreshingFileJwtRetrieverWithInvalidFile() throws Exception {
// Should fail because while the parent path exists, the file itself doesn't. // Should fail because while the parent path exists, the file itself doesn't.
File tmpDir = createTempDir("this-directory-does-exist"); File tmpDir = createTempDir("this-directory-does-exist");
File accessTokenFile = new File(tmpDir, "this-file-does-not-exist.json"); File accessTokenFile = new File(tmpDir, "this-file-does-not-exist.json");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString()); System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString());
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
Map<String, Object> jaasConfig = Collections.emptyMap(); Map<String, Object> jaasConfig = Collections.emptyMap();
assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, jaasConfig), "that doesn't exist");
try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfig)) {
assertThrowsWithMessage(ConfigException.class, jwtRetriever::init, "that doesn't exist");
}
} }
@Test @Test
@ -87,15 +100,53 @@ public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest {
File tmpDir = createTempDir("not_allowed"); File tmpDir = createTempDir("not_allowed");
File accessTokenFile = new File(tmpDir, "not_allowed.json"); File accessTokenFile = new File(tmpDir, "not_allowed.json");
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, Collections.emptyMap()),
ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG); try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, Collections.emptyMap())) {
assertThrowsWithMessage(ConfigException.class, jwtRetriever::init, ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG);
}
}
@Test
public void testConfigureWithAccessTokenFile() throws Exception {
String expected = "{}";
File tmpDir = createTempDir("access-token");
File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", expected);
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString());
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
DefaultJwtRetriever jwtRetriever = new DefaultJwtRetriever(
configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
Map.of()
);
assertDoesNotThrow(jwtRetriever::init);
assertInstanceOf(FileJwtRetriever.class, jwtRetriever.delegate());
}
@Test
public void testConfigureWithAccessClientCredentials() {
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com");
Map<String, Object> jaasConfigs = new HashMap<>();
jaasConfigs.put(CLIENT_ID_CONFIG, "an ID");
jaasConfigs.put(CLIENT_SECRET_CONFIG, "a secret");
DefaultJwtRetriever jwtRetriever = new DefaultJwtRetriever(
configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
jaasConfigs
);
assertDoesNotThrow(jwtRetriever::init);
assertInstanceOf(HttpJwtRetriever.class, jwtRetriever.delegate());
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("urlencodeHeaderSupplier") @MethodSource("urlencodeHeaderSupplier")
public void testUrlencodeHeader(Map<String, Object> configs, boolean expectedValue) { public void testUrlencodeHeader(Map<String, Object> configs, boolean expectedValue) {
ConfigurationUtils cu = new ConfigurationUtils(configs); ConfigurationUtils cu = new ConfigurationUtils(configs);
boolean actualValue = AccessTokenRetrieverFactory.validateUrlencodeHeader(cu); boolean actualValue = DefaultJwtRetriever.validateUrlencodeHeader(cu);
assertEquals(expectedValue, actualValue); assertEquals(expectedValue, actualValue);
} }

View File

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.jose4j.jws.AlgorithmIdentifiers;
import org.junit.jupiter.api.Test;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
public class DefaultJwtValidatorTest extends OAuthBearerTest {
@Test
public void testConfigureWithVerificationKeyResolver() {
AccessTokenBuilder builder = new AccessTokenBuilder()
.alg(AlgorithmIdentifiers.RSA_USING_SHA256);
CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
Map<String, ?> configs = getSaslConfigs();
DefaultJwtValidator jwtValidator = new DefaultJwtValidator(
configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
verificationKeyResolver
);
assertDoesNotThrow(jwtValidator::init);
assertInstanceOf(BrokerJwtValidator.class, jwtValidator.delegate());
}
@Test
public void testConfigureWithoutVerificationKeyResolver() {
Map<String, ?> configs = getSaslConfigs();
DefaultJwtValidator jwtValidator = new DefaultJwtValidator(
configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM
);
assertDoesNotThrow(jwtValidator::init);
assertInstanceOf(ClientJwtValidator.class, jwtValidator.delegate());
}
private CloseableVerificationKeyResolver createVerificationKeyResolver(AccessTokenBuilder builder) {
return (jws, nestingContext) -> builder.jwk().getPublicKey();
}
}

View File

@ -39,20 +39,20 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { public class HttpJwtRetrieverTest extends OAuthBearerTest {
@Test @Test
public void test() throws IOException { public void test() throws IOException {
String expectedResponse = "Hiya, buddy"; String expectedResponse = "Hiya, buddy";
HttpURLConnection mockedCon = createHttpURLConnection(expectedResponse); HttpURLConnection mockedCon = createHttpURLConnection(expectedResponse);
String response = HttpAccessTokenRetriever.post(mockedCon, null, null, null, null); String response = HttpJwtRetriever.post(mockedCon, null, null, null, null);
assertEquals(expectedResponse, response); assertEquals(expectedResponse, response);
} }
@Test @Test
public void testEmptyResponse() throws IOException { public void testEmptyResponse() throws IOException {
HttpURLConnection mockedCon = createHttpURLConnection(""); HttpURLConnection mockedCon = createHttpURLConnection("");
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); assertThrows(IOException.class, () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
} }
@Test @Test
@ -60,7 +60,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
HttpURLConnection mockedCon = createHttpURLConnection("dummy"); HttpURLConnection mockedCon = createHttpURLConnection("dummy");
when(mockedCon.getInputStream()).thenThrow(new IOException("Can't read")); when(mockedCon.getInputStream()).thenThrow(new IOException("Can't read"));
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); assertThrows(IOException.class, () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
} }
@Test @Test
@ -72,7 +72,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
.getBytes(StandardCharsets.UTF_8))); .getBytes(StandardCharsets.UTF_8)));
when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST); when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST);
UnretryableException ioe = assertThrows(UnretryableException.class, UnretryableException ioe = assertThrows(UnretryableException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}")); assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}"));
} }
@ -85,7 +85,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
.getBytes(StandardCharsets.UTF_8))); .getBytes(StandardCharsets.UTF_8)));
when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR); when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR);
IOException ioe = assertThrows(IOException.class, IOException ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}")); assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}"));
// error response body has different keys // error response body has different keys
@ -93,7 +93,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
"{\"errorCode\":\"some_arg\", \"errorSummary\":\"some problem with arg\"}" "{\"errorCode\":\"some_arg\", \"errorSummary\":\"some problem with arg\"}"
.getBytes(StandardCharsets.UTF_8))); .getBytes(StandardCharsets.UTF_8)));
ioe = assertThrows(IOException.class, ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}")); assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}"));
// error response is valid json but unknown keys // error response is valid json but unknown keys
@ -101,7 +101,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
"{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}" "{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}"
.getBytes(StandardCharsets.UTF_8))); .getBytes(StandardCharsets.UTF_8)));
ioe = assertThrows(IOException.class, ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}")); assertTrue(ioe.getMessage().contains("{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}"));
} }
@ -113,7 +113,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
"non json error output".getBytes(StandardCharsets.UTF_8))); "non json error output".getBytes(StandardCharsets.UTF_8)));
when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR); when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR);
IOException ioe = assertThrows(IOException.class, IOException ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{non json error output}")); assertTrue(ioe.getMessage().contains("{non json error output}"));
} }
@ -124,7 +124,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
r.nextBytes(expected); r.nextBytes(expected);
InputStream in = new ByteArrayInputStream(expected); InputStream in = new ByteArrayInputStream(expected);
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
HttpAccessTokenRetriever.copy(in, out); HttpJwtRetriever.copy(in, out);
assertArrayEquals(expected, out.toByteArray()); assertArrayEquals(expected, out.toByteArray());
} }
@ -133,7 +133,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
InputStream mockedIn = mock(InputStream.class); InputStream mockedIn = mock(InputStream.class);
OutputStream out = new ByteArrayOutputStream(); OutputStream out = new ByteArrayOutputStream();
when(mockedIn.read(any(byte[].class))).thenThrow(new IOException()); when(mockedIn.read(any(byte[].class))).thenThrow(new IOException());
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.copy(mockedIn, out)); assertThrows(IOException.class, () -> HttpJwtRetriever.copy(mockedIn, out));
} }
@Test @Test
@ -143,7 +143,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
ObjectNode node = mapper.createObjectNode(); ObjectNode node = mapper.createObjectNode();
node.put("access_token", expected); node.put("access_token", expected);
String actual = HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node)); String actual = HttpJwtRetriever.parseAccessToken(mapper.writeValueAsString(node));
assertEquals(expected, actual); assertEquals(expected, actual);
} }
@ -153,7 +153,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
ObjectNode node = mapper.createObjectNode(); ObjectNode node = mapper.createObjectNode();
node.put("access_token", ""); node.put("access_token", "");
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node))); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.parseAccessToken(mapper.writeValueAsString(node)));
} }
@Test @Test
@ -162,12 +162,12 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
ObjectNode node = mapper.createObjectNode(); ObjectNode node = mapper.createObjectNode();
node.put("sub", "jdoe"); node.put("sub", "jdoe");
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node))); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.parseAccessToken(mapper.writeValueAsString(node)));
} }
@Test @Test
public void testParseAccessTokenInvalidJson() { public void testParseAccessTokenInvalidJson() {
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.parseAccessToken("not valid JSON")); assertThrows(IOException.class, () -> HttpJwtRetriever.parseAccessToken("not valid JSON"));
} }
@Test @Test
@ -184,27 +184,27 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
} }
private void assertAuthorizationHeader(String clientId, String clientSecret, boolean urlencode, String expected) { private void assertAuthorizationHeader(String clientId, String clientSecret, boolean urlencode, String expected) {
String actual = HttpAccessTokenRetriever.formatAuthorizationHeader(clientId, clientSecret, urlencode); String actual = HttpJwtRetriever.formatAuthorizationHeader(clientId, clientSecret, urlencode);
assertEquals(expected, actual, String.format("Expected the HTTP Authorization header generated for client ID \"%s\" and client secret \"%s\" to match", clientId, clientSecret)); assertEquals(expected, actual, String.format("Expected the HTTP Authorization header generated for client ID \"%s\" and client secret \"%s\" to match", clientId, clientSecret));
} }
@Test @Test
public void testFormatAuthorizationHeaderMissingValues() { public void testFormatAuthorizationHeaderMissingValues() {
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(null, "secret", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(null, "secret", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", null, false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("id", null, false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(null, null, false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(null, null, false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("", "secret", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("", "secret", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", "", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("id", "", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("", "", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("", "", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(" ", "secret", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(" ", "secret", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", " ", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("id", " ", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(" ", " ", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(" ", " ", false));
} }
@Test @Test
public void testFormatRequestBody() { public void testFormatRequestBody() {
String expected = "grant_type=client_credentials&scope=scope"; String expected = "grant_type=client_credentials&scope=scope";
String actual = HttpAccessTokenRetriever.formatRequestBody("scope"); String actual = HttpJwtRetriever.formatRequestBody("scope");
assertEquals(expected, actual); assertEquals(expected, actual);
} }
@ -214,24 +214,24 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
String exclamationMark = "%21"; String exclamationMark = "%21";
String expected = String.format("grant_type=client_credentials&scope=earth+is+great%s", exclamationMark); String expected = String.format("grant_type=client_credentials&scope=earth+is+great%s", exclamationMark);
String actual = HttpAccessTokenRetriever.formatRequestBody("earth is great!"); String actual = HttpJwtRetriever.formatRequestBody("earth is great!");
assertEquals(expected, actual); assertEquals(expected, actual);
expected = String.format("grant_type=client_credentials&scope=what+on+earth%s%s%s%s%s", questionMark, exclamationMark, questionMark, exclamationMark, questionMark); expected = String.format("grant_type=client_credentials&scope=what+on+earth%s%s%s%s%s", questionMark, exclamationMark, questionMark, exclamationMark, questionMark);
actual = HttpAccessTokenRetriever.formatRequestBody("what on earth?!?!?"); actual = HttpJwtRetriever.formatRequestBody("what on earth?!?!?");
assertEquals(expected, actual); assertEquals(expected, actual);
} }
@Test @Test
public void testFormatRequestBodyMissingValues() { public void testFormatRequestBodyMissingValues() {
String expected = "grant_type=client_credentials"; String expected = "grant_type=client_credentials";
String actual = HttpAccessTokenRetriever.formatRequestBody(null); String actual = HttpJwtRetriever.formatRequestBody(null);
assertEquals(expected, actual); assertEquals(expected, actual);
actual = HttpAccessTokenRetriever.formatRequestBody(""); actual = HttpJwtRetriever.formatRequestBody("");
assertEquals(expected, actual); assertEquals(expected, actual);
actual = HttpAccessTokenRetriever.formatRequestBody(" "); actual = HttpJwtRetriever.formatRequestBody(" ");
assertEquals(expected, actual); assertEquals(expected, actual);
} }

View File

@ -26,42 +26,42 @@ import org.junit.jupiter.api.TestInstance.Lifecycle;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@TestInstance(Lifecycle.PER_CLASS) @TestInstance(Lifecycle.PER_CLASS)
public abstract class AccessTokenValidatorTest extends OAuthBearerTest { public abstract class JwtValidatorTest extends OAuthBearerTest {
protected abstract AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder accessTokenBuilder) throws Exception; protected abstract JwtValidator createJwtValidator(AccessTokenBuilder accessTokenBuilder) throws Exception;
protected AccessTokenValidator createAccessTokenValidator() throws Exception { protected JwtValidator createJwtValidator() throws Exception {
AccessTokenBuilder builder = new AccessTokenBuilder(); AccessTokenBuilder builder = new AccessTokenBuilder();
return createAccessTokenValidator(builder); return createJwtValidator(builder);
} }
@Test @Test
public void testNull() throws Exception { public void testNull() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
assertThrowsWithMessage(ValidateException.class, () -> validator.validate(null), "Malformed JWT provided; expected three sections (header, payload, and signature)"); assertThrowsWithMessage(ValidateException.class, () -> validator.validate(null), "Malformed JWT provided; expected three sections (header, payload, and signature)");
} }
@Test @Test
public void testEmptyString() throws Exception { public void testEmptyString() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
assertThrowsWithMessage(ValidateException.class, () -> validator.validate(""), "Malformed JWT provided; expected three sections (header, payload, and signature)"); assertThrowsWithMessage(ValidateException.class, () -> validator.validate(""), "Malformed JWT provided; expected three sections (header, payload, and signature)");
} }
@Test @Test
public void testWhitespace() throws Exception { public void testWhitespace() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
assertThrowsWithMessage(ValidateException.class, () -> validator.validate(" "), "Malformed JWT provided; expected three sections (header, payload, and signature)"); assertThrowsWithMessage(ValidateException.class, () -> validator.validate(" "), "Malformed JWT provided; expected three sections (header, payload, and signature)");
} }
@Test @Test
public void testEmptySections() throws Exception { public void testEmptySections() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
assertThrowsWithMessage(ValidateException.class, () -> validator.validate(".."), "Malformed JWT provided; expected three sections (header, payload, and signature)"); assertThrowsWithMessage(ValidateException.class, () -> validator.validate(".."), "Malformed JWT provided; expected three sections (header, payload, and signature)");
} }
@Test @Test
public void testMissingHeader() throws Exception { public void testMissingHeader() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
String header = ""; String header = "";
String payload = createBase64JsonJwtSection(node -> { }); String payload = createBase64JsonJwtSection(node -> { });
String signature = ""; String signature = "";
@ -71,7 +71,7 @@ public abstract class AccessTokenValidatorTest extends OAuthBearerTest {
@Test @Test
public void testMissingPayload() throws Exception { public void testMissingPayload() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE)); String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE));
String payload = ""; String payload = "";
String signature = ""; String signature = "";
@ -81,7 +81,7 @@ public abstract class AccessTokenValidatorTest extends OAuthBearerTest {
@Test @Test
public void testMissingSignature() throws Exception { public void testMissingSignature() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE)); String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE));
String payload = createBase64JsonJwtSection(node -> { }); String payload = createBase64JsonJwtSection(node -> { });
String signature = ""; String signature = "";

View File

@ -19,9 +19,6 @@ package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.config.AbstractConfig; import org.apache.kafka.common.config.AbstractConfig;
import org.apache.kafka.common.config.ConfigDef; import org.apache.kafka.common.config.ConfigDef;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.authenticator.TestJaasConfig;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.utils.Utils; import org.apache.kafka.common.utils.Utils;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
@ -52,8 +49,6 @@ import java.util.Map;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.function.Consumer; import java.util.function.Consumer;
import javax.security.auth.login.AppConfigurationEntry;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
@ -80,18 +75,6 @@ public abstract class OAuthBearerTest {
expectedSubstring)); expectedSubstring));
} }
protected void configureHandler(AuthenticateCallbackHandler handler,
Map<String, ?> configs,
Map<String, Object> jaasConfig) {
TestJaasConfig config = new TestJaasConfig();
config.createOrUpdateEntry("KafkaClient", OAuthBearerLoginModule.class.getName(), jaasConfig);
AppConfigurationEntry kafkaClient = config.getAppConfigurationEntry("KafkaClient")[0];
handler.configure(configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
Collections.singletonList(kafkaClient));
}
protected String createBase64JsonJwtSection(Consumer<ObjectNode> c) { protected String createBase64JsonJwtSection(Consumer<ObjectNode> c) {
String json = createJsonJwtSection(c); String json = createJsonJwtSection(c);
@ -212,4 +195,11 @@ public abstract class OAuthBearerTest {
return jwk; return jwk;
} }
protected String createAccessKey(String header, String payload, String signature) {
Base64.Encoder enc = Base64.getEncoder();
header = enc.encodeToString(Utils.utf8(header));
payload = enc.encodeToString(Utils.utf8(payload));
signature = enc.encodeToString(Utils.utf8(signature));
return String.format("%s.%s.%s", header, payload, signature);
}
} }

View File

@ -24,11 +24,12 @@ import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.config.SaslConfigs; import org.apache.kafka.common.config.SaslConfigs;
import org.apache.kafka.common.config.SslConfigs; import org.apache.kafka.common.config.SslConfigs;
import org.apache.kafka.common.config.types.Password; import org.apache.kafka.common.config.types.Password;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever; import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetrieverFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver; import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.VerificationKeyResolverFactory; import org.apache.kafka.common.security.oauthbearer.internals.secured.VerificationKeyResolverFactory;
import org.apache.kafka.common.utils.Exit; import org.apache.kafka.common.utils.Exit;
@ -139,16 +140,19 @@ public class OAuthCompatibilityTool {
{ {
// Client side... // Client side...
try (AccessTokenRetriever atr = AccessTokenRetrieverFactory.create(configs, jaasConfigs)) { try (JwtRetriever atr = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfigs)) {
atr.init(); atr.init();
AccessTokenValidator atv = AccessTokenValidatorFactory.create(configs);
System.out.println("PASSED 1/5: client configuration");
accessToken = atr.retrieve(); try (JwtValidator atv = new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM)) {
System.out.println("PASSED 2/5: client JWT retrieval"); atv.init();
System.out.println("PASSED 1/5: client configuration");
atv.validate(accessToken); accessToken = atr.retrieve();
System.out.println("PASSED 3/5: client JWT validation"); System.out.println("PASSED 2/5: client JWT retrieval");
atv.validate(accessToken);
System.out.println("PASSED 3/5: client JWT validation");
}
} }
} }
@ -156,11 +160,14 @@ public class OAuthCompatibilityTool {
// Broker side... // Broker side...
try (CloseableVerificationKeyResolver vkr = VerificationKeyResolverFactory.create(configs, jaasConfigs)) { try (CloseableVerificationKeyResolver vkr = VerificationKeyResolverFactory.create(configs, jaasConfigs)) {
vkr.init(); vkr.init();
AccessTokenValidator atv = AccessTokenValidatorFactory.create(configs, vkr);
System.out.println("PASSED 4/5: broker configuration");
atv.validate(accessToken); try (JwtValidator atv = new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, vkr)) {
System.out.println("PASSED 5/5: broker JWT validation"); atv.init();
System.out.println("PASSED 4/5: broker configuration");
atv.validate(accessToken);
System.out.println("PASSED 5/5: broker JWT validation");
}
} }
} }