KAFKA-18847: Refactor OAuth layer to improve reusability 1/N (#19622)
CI / build (push) Waiting to run Details

Rename `AccessTokenRetriever` and `AccessTokenValidator` to
`JwtRetriever` and `JwtValidator`, respectively. Also converting the
factory pattern classes `AccessTokenRetrieverFactory` and
`AccessTokenValidatorFactory` into delegate/wrapper classes
`DefaultJwtRetriever` and `DefaultJwtValidator`, respectively.

These are all internal changes, no configuration, user APIs, RPCs, etc.
were changed.

Reviewers: Manikumar Reddy <manikumar@confluent.io>, Ken Huang
 <s7133700@gmail.com>, Lianet Magrans <lmagrans@confluent.io>

---------

Co-authored-by: Ken Huang <s7133700@gmail.com>
This commit is contained in:
Kirk True 2025-05-13 09:35:20 -07:00 committed by GitHub
parent c16c240bd1
commit c60c83aaba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 598 additions and 476 deletions

View File

@ -24,12 +24,13 @@ import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.auth.SaslExtensions; import org.apache.kafka.common.security.auth.SaslExtensions;
import org.apache.kafka.common.security.auth.SaslExtensionsCallback; import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse; import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever; import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetrieverFactory; import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils; import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException; import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException;
import org.apache.kafka.common.utils.Utils;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
@ -179,55 +180,48 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand
private Map<String, Object> moduleOptions; private Map<String, Object> moduleOptions;
private AccessTokenRetriever accessTokenRetriever; private JwtRetriever jwtRetriever;
private AccessTokenValidator accessTokenValidator; private JwtValidator jwtValidator;
private boolean isInitialized = false;
@Override @Override
public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) { public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
moduleOptions = JaasOptionsUtils.getOptions(saslMechanism, jaasConfigEntries); Map<String, Object> moduleOptions = JaasOptionsUtils.getOptions(saslMechanism, jaasConfigEntries);
AccessTokenRetriever accessTokenRetriever = AccessTokenRetrieverFactory.create(configs, saslMechanism, moduleOptions); JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, saslMechanism, moduleOptions);
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs, saslMechanism); JwtValidator jwtValidator = new DefaultJwtValidator(configs, saslMechanism);
init(accessTokenRetriever, accessTokenValidator); init(moduleOptions, jwtRetriever, jwtValidator);
}
public void init(AccessTokenRetriever accessTokenRetriever, AccessTokenValidator accessTokenValidator) {
this.accessTokenRetriever = accessTokenRetriever;
this.accessTokenValidator = accessTokenValidator;
try {
this.accessTokenRetriever.init();
} catch (IOException e) {
throw new KafkaException("The OAuth login configuration encountered an error when initializing the AccessTokenRetriever", e);
}
isInitialized = true;
} }
/* /*
* Package-visible for testing. * Package-visible for testing.
*/ */
void init(Map<String, Object> moduleOptions, JwtRetriever jwtRetriever, JwtValidator jwtValidator) {
this.moduleOptions = moduleOptions;
this.jwtRetriever = jwtRetriever;
this.jwtValidator = jwtValidator;
AccessTokenRetriever getAccessTokenRetriever() { try {
return accessTokenRetriever; this.jwtRetriever.init();
} catch (IOException e) {
throw new KafkaException("The OAuth login callback encountered an error when initializing the JwtRetriever", e);
}
try {
this.jwtValidator.init();
} catch (IOException e) {
throw new KafkaException("The OAuth login callback encountered an error when initializing the JwtValidator", e);
}
} }
@Override @Override
public void close() { public void close() {
if (accessTokenRetriever != null) { Utils.closeQuietly(jwtRetriever, "JWT retriever");
try { Utils.closeQuietly(jwtValidator, "JWT validator");
this.accessTokenRetriever.close();
} catch (IOException e) {
log.warn("The OAuth login configuration encountered an error when closing the AccessTokenRetriever", e);
}
}
} }
@Override @Override
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
checkInitialized(); checkConfigured();
for (Callback callback : callbacks) { for (Callback callback : callbacks) {
if (callback instanceof OAuthBearerTokenCallback) { if (callback instanceof OAuthBearerTokenCallback) {
@ -241,11 +235,11 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand
} }
private void handleTokenCallback(OAuthBearerTokenCallback callback) throws IOException { private void handleTokenCallback(OAuthBearerTokenCallback callback) throws IOException {
checkInitialized(); checkConfigured();
String accessToken = accessTokenRetriever.retrieve(); String accessToken = jwtRetriever.retrieve();
try { try {
OAuthBearerToken token = accessTokenValidator.validate(accessToken); OAuthBearerToken token = jwtValidator.validate(accessToken);
callback.token(token); callback.token(token);
} catch (ValidateException e) { } catch (ValidateException e) {
log.warn(e.getMessage(), e); log.warn(e.getMessage(), e);
@ -254,7 +248,7 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand
} }
private void handleExtensionsCallback(SaslExtensionsCallback callback) { private void handleExtensionsCallback(SaslExtensionsCallback callback) {
checkInitialized(); checkConfigured();
Map<String, String> extensions = new HashMap<>(); Map<String, String> extensions = new HashMap<>();
@ -286,9 +280,9 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand
callback.extensions(saslExtensions); callback.extensions(saslExtensions);
} }
private void checkInitialized() { private void checkConfigured() {
if (!isInitialized) if (moduleOptions == null || jwtRetriever == null || jwtValidator == null)
throw new IllegalStateException(String.format("To use %s, first call the configure or init method", getClass().getSimpleName())); throw new IllegalStateException(String.format("To use %s, first call the configure method", getClass().getSimpleName()));
} }
} }

View File

@ -19,13 +19,14 @@ package org.apache.kafka.common.security.oauthbearer;
import org.apache.kafka.common.KafkaException; import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler; import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver; import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils; import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.RefreshingHttpsJwksVerificationKeyResolver; import org.apache.kafka.common.security.oauthbearer.internals.secured.RefreshingHttpsJwksVerificationKeyResolver;
import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException; import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException;
import org.apache.kafka.common.security.oauthbearer.internals.secured.VerificationKeyResolverFactory; import org.apache.kafka.common.security.oauthbearer.internals.secured.VerificationKeyResolverFactory;
import org.apache.kafka.common.utils.Utils;
import org.jose4j.jws.JsonWebSignature; import org.jose4j.jws.JsonWebSignature;
import org.jose4j.jwx.JsonWebStructure; import org.jose4j.jwx.JsonWebStructure;
@ -119,9 +120,7 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback
private CloseableVerificationKeyResolver verificationKeyResolver; private CloseableVerificationKeyResolver verificationKeyResolver;
private AccessTokenValidator accessTokenValidator; private JwtValidator jwtValidator;
private boolean isInitialized = false;
@Override @Override
public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) { public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
@ -135,37 +134,39 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback
new RefCountingVerificationKeyResolver(VerificationKeyResolverFactory.create(configs, saslMechanism, moduleOptions))); new RefCountingVerificationKeyResolver(VerificationKeyResolverFactory.create(configs, saslMechanism, moduleOptions)));
} }
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs, saslMechanism, verificationKeyResolver); JwtValidator jwtValidator = new DefaultJwtValidator(configs, saslMechanism, verificationKeyResolver);
init(verificationKeyResolver, accessTokenValidator); init(verificationKeyResolver, jwtValidator);
} }
public void init(CloseableVerificationKeyResolver verificationKeyResolver, AccessTokenValidator accessTokenValidator) { /*
* Package-visible for testing.
*/
void init(CloseableVerificationKeyResolver verificationKeyResolver, JwtValidator jwtValidator) {
this.verificationKeyResolver = verificationKeyResolver; this.verificationKeyResolver = verificationKeyResolver;
this.accessTokenValidator = accessTokenValidator; this.jwtValidator = jwtValidator;
try { try {
verificationKeyResolver.init(); verificationKeyResolver.init();
} catch (Exception e) { } catch (Exception e) {
throw new KafkaException("The OAuth validator configuration encountered an error when initializing the VerificationKeyResolver", e); throw new KafkaException("The OAuth validator callback encountered an error when initializing the VerificationKeyResolver", e);
} }
isInitialized = true; try {
jwtValidator.init();
} catch (IOException e) {
throw new KafkaException("The OAuth validator callback encountered an error when initializing the JwtValidator", e);
}
} }
@Override @Override
public void close() { public void close() {
if (verificationKeyResolver != null) { Utils.closeQuietly(jwtValidator, "JWT validator");
try { Utils.closeQuietly(verificationKeyResolver, "JWT verification key resolver");
verificationKeyResolver.close();
} catch (Exception e) {
log.error(e.getMessage(), e);
}
}
} }
@Override @Override
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
checkInitialized(); checkConfigured();
for (Callback callback : callbacks) { for (Callback callback : callbacks) {
if (callback instanceof OAuthBearerValidatorCallback) { if (callback instanceof OAuthBearerValidatorCallback) {
@ -179,12 +180,12 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback
} }
private void handleValidatorCallback(OAuthBearerValidatorCallback callback) { private void handleValidatorCallback(OAuthBearerValidatorCallback callback) {
checkInitialized(); checkConfigured();
OAuthBearerToken token; OAuthBearerToken token;
try { try {
token = accessTokenValidator.validate(callback.tokenValue()); token = jwtValidator.validate(callback.tokenValue());
callback.token(token); callback.token(token);
} catch (ValidateException e) { } catch (ValidateException e) {
log.warn(e.getMessage(), e); log.warn(e.getMessage(), e);
@ -193,14 +194,14 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback
} }
private void handleExtensionsValidatorCallback(OAuthBearerExtensionsValidatorCallback extensionsValidatorCallback) { private void handleExtensionsValidatorCallback(OAuthBearerExtensionsValidatorCallback extensionsValidatorCallback) {
checkInitialized(); checkConfigured();
extensionsValidatorCallback.inputExtensions().map().forEach((extensionName, v) -> extensionsValidatorCallback.valid(extensionName)); extensionsValidatorCallback.inputExtensions().map().forEach((extensionName, v) -> extensionsValidatorCallback.valid(extensionName));
} }
private void checkInitialized() { private void checkConfigured() {
if (!isInitialized) if (verificationKeyResolver == null || jwtValidator == null)
throw new IllegalStateException(String.format("To use %s, first call the configure or init method", getClass().getSimpleName())); throw new IllegalStateException(String.format("To use %s, first call the configure method", getClass().getSimpleName()));
} }
/** /**

View File

@ -1,73 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.jose4j.keys.resolvers.VerificationKeyResolver;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_ISSUER;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SCOPE_CLAIM_NAME;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SUB_CLAIM_NAME;
public class AccessTokenValidatorFactory {
public static AccessTokenValidator create(Map<String, ?> configs) {
return create(configs, (String) null);
}
public static AccessTokenValidator create(Map<String, ?> configs, String saslMechanism) {
ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism);
String scopeClaimName = cu.get(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME);
String subClaimName = cu.get(SASL_OAUTHBEARER_SUB_CLAIM_NAME);
return new LoginAccessTokenValidator(scopeClaimName, subClaimName);
}
public static AccessTokenValidator create(Map<String, ?> configs,
VerificationKeyResolver verificationKeyResolver) {
return create(configs, null, verificationKeyResolver);
}
public static AccessTokenValidator create(Map<String, ?> configs,
String saslMechanism,
VerificationKeyResolver verificationKeyResolver) {
ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism);
Set<String> expectedAudiences = null;
List<String> l = cu.get(SASL_OAUTHBEARER_EXPECTED_AUDIENCE);
if (l != null)
expectedAudiences = Set.copyOf(l);
Integer clockSkew = cu.validateInteger(SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS, false);
String expectedIssuer = cu.validateString(SASL_OAUTHBEARER_EXPECTED_ISSUER, false);
String scopeClaimName = cu.validateString(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME);
String subClaimName = cu.validateString(SASL_OAUTHBEARER_SUB_CLAIM_NAME);
return new ValidatorAccessTokenValidator(clockSkew,
expectedAudiences,
expectedIssuer,
verificationKeyResolver,
scopeClaimName,
subClaimName);
}
}

View File

@ -38,7 +38,7 @@ import java.util.Set;
import static org.jose4j.jwa.AlgorithmConstraints.DISALLOW_NONE; import static org.jose4j.jwa.AlgorithmConstraints.DISALLOW_NONE;
/** /**
* ValidatorAccessTokenValidator is an implementation of {@link AccessTokenValidator} that is used * {@code BrokerJwtValidator} is an implementation of {@link JwtValidator} that is used
* by the broker to perform more extensive validation of the JWT access token that is received * by the broker to perform more extensive validation of the JWT access token that is received
* from the client, but ultimately from posting the client credentials to the OAuth/OIDC provider's * from the client, but ultimately from posting the client credentials to the OAuth/OIDC provider's
* token endpoint. * token endpoint.
@ -62,9 +62,9 @@ import static org.jose4j.jwa.AlgorithmConstraints.DISALLOW_NONE;
* </ol> * </ol>
*/ */
public class ValidatorAccessTokenValidator implements AccessTokenValidator { public class BrokerJwtValidator implements JwtValidator {
private static final Logger log = LoggerFactory.getLogger(ValidatorAccessTokenValidator.class); private static final Logger log = LoggerFactory.getLogger(BrokerJwtValidator.class);
private final JwtConsumer jwtConsumer; private final JwtConsumer jwtConsumer;
@ -73,7 +73,7 @@ public class ValidatorAccessTokenValidator implements AccessTokenValidator {
private final String subClaimName; private final String subClaimName;
/** /**
* Creates a new ValidatorAccessTokenValidator that will be used by the broker for more * Creates a new {@code BrokerJwtValidator} that will be used by the broker for more
* thorough validation of the JWT. * thorough validation of the JWT.
* *
* @param clockSkew The optional value (in seconds) to allow for differences * @param clockSkew The optional value (in seconds) to allow for differences
@ -112,7 +112,7 @@ public class ValidatorAccessTokenValidator implements AccessTokenValidator {
* @see VerificationKeyResolver * @see VerificationKeyResolver
*/ */
public ValidatorAccessTokenValidator(Integer clockSkew, public BrokerJwtValidator(Integer clockSkew,
Set<String> expectedAudiences, Set<String> expectedAudiences,
String expectedIssuer, String expectedIssuer,
VerificationKeyResolver verificationKeyResolver, VerificationKeyResolver verificationKeyResolver,

View File

@ -33,7 +33,7 @@ import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARE
import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME; import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME;
/** /**
* LoginAccessTokenValidator is an implementation of {@link AccessTokenValidator} that is used * {@code ClientJwtValidator} is an implementation of {@link JwtValidator} that is used
* by the client to perform some rudimentary validation of the JWT access token that is received * by the client to perform some rudimentary validation of the JWT access token that is received
* as part of the response from posting the client credentials to the OAuth/OIDC provider's * as part of the response from posting the client credentials to the OAuth/OIDC provider's
* token endpoint. * token endpoint.
@ -46,13 +46,13 @@ import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARE
* <a href="https://tools.ietf.org/html/rfc6750#section-2.1">RFC 6750 Section 2.1</a> * <a href="https://tools.ietf.org/html/rfc6750#section-2.1">RFC 6750 Section 2.1</a>
* </li> * </li>
* <li>Basic conversion of the token into an in-memory map</li> * <li>Basic conversion of the token into an in-memory map</li>
* <li>Presence of scope, <code>exp</code>, subject, and <code>iat</code> claims</li> * <li>Presence of <code>scope</code>, <code>exp</code>, <code>subject</code>, and <code>iat</code> claims</li>
* </ol> * </ol>
*/ */
public class LoginAccessTokenValidator implements AccessTokenValidator { public class ClientJwtValidator implements JwtValidator {
private static final Logger log = LoggerFactory.getLogger(LoginAccessTokenValidator.class); private static final Logger log = LoggerFactory.getLogger(ClientJwtValidator.class);
public static final String EXPIRATION_CLAIM_NAME = "exp"; public static final String EXPIRATION_CLAIM_NAME = "exp";
@ -63,14 +63,14 @@ public class LoginAccessTokenValidator implements AccessTokenValidator {
private final String subClaimName; private final String subClaimName;
/** /**
* Creates a new LoginAccessTokenValidator that will be used by the client for lightweight * Creates a new {@code ClientJwtValidator} that will be used by the client for lightweight
* validation of the JWT. * validation of the JWT.
* *
* @param scopeClaimName Name of the scope claim to use; must be non-<code>null</code> * @param scopeClaimName Name of the scope claim to use; must be non-<code>null</code>
* @param subClaimName Name of the subject claim to use; must be non-<code>null</code> * @param subClaimName Name of the subject claim to use; must be non-<code>null</code>
*/ */
public LoginAccessTokenValidator(String scopeClaimName, String subClaimName) { public ClientJwtValidator(String scopeClaimName, String subClaimName) {
this.scopeClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SCOPE_CLAIM_NAME, scopeClaimName); this.scopeClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SCOPE_CLAIM_NAME, scopeClaimName);
this.subClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME, subClaimName); this.subClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME, subClaimName);
} }

View File

@ -18,7 +18,9 @@
package org.apache.kafka.common.security.oauthbearer.internals.secured; package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.config.SaslConfigs; import org.apache.kafka.common.config.SaslConfigs;
import org.apache.kafka.common.utils.Utils;
import java.io.IOException;
import java.net.URL; import java.net.URL;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
@ -36,32 +38,33 @@ import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallb
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG; import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG;
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.SCOPE_CONFIG; import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.SCOPE_CONFIG;
public class AccessTokenRetrieverFactory { /**
* {@code DefaultJwtRetriever} instantiates and delegates {@link JwtRetriever} API calls to an embedded implementation
/** * based on configuration. If {@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL} is configured with a
* Create an {@link AccessTokenRetriever} from the given SASL and JAAS configuration. * {@code file}-based URL, a {@link FileJwtRetriever} is created and the JWT is expected be contained in the file
* * specified. Otherwise, it's assumed to be an HTTP/HTTPS-based URL, so an {@link HttpJwtRetriever} is created.
* <b>Note</b>: the returned <code>AccessTokenRetriever</code> is <em>not</em> initialized
* here and must be done by the caller prior to use.
*
* @param configs SASL configuration
* @param jaasConfig JAAS configuration
*
* @return Non-<code>null</code> {@link AccessTokenRetriever}
*/ */
public class DefaultJwtRetriever implements JwtRetriever {
public static AccessTokenRetriever create(Map<String, ?> configs, Map<String, Object> jaasConfig) { private final Map<String, ?> configs;
return create(configs, null, jaasConfig); private final String saslMechanism;
private final Map<String, Object> jaasConfig;
private JwtRetriever delegate;
public DefaultJwtRetriever(Map<String, ?> configs, String saslMechanism, Map<String, Object> jaasConfig) {
this.configs = configs;
this.saslMechanism = saslMechanism;
this.jaasConfig = jaasConfig;
} }
public static AccessTokenRetriever create(Map<String, ?> configs, @Override
String saslMechanism, public void init() throws IOException {
Map<String, Object> jaasConfig) {
ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism); ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism);
URL tokenEndpointUrl = cu.validateUrl(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL); URL tokenEndpointUrl = cu.validateUrl(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL);
if (tokenEndpointUrl.getProtocol().toLowerCase(Locale.ROOT).equals("file")) { if (tokenEndpointUrl.getProtocol().toLowerCase(Locale.ROOT).equals("file")) {
return new FileTokenRetriever(cu.validateFile(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL)); delegate = new FileJwtRetriever(cu.validateFile(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL));
} else { } else {
JaasOptionsUtils jou = new JaasOptionsUtils(jaasConfig); JaasOptionsUtils jou = new JaasOptionsUtils(jaasConfig);
String clientId = jou.validateString(CLIENT_ID_CONFIG); String clientId = jou.validateString(CLIENT_ID_CONFIG);
@ -75,7 +78,7 @@ public class AccessTokenRetrieverFactory {
boolean urlencodeHeader = validateUrlencodeHeader(cu); boolean urlencodeHeader = validateUrlencodeHeader(cu);
return new HttpAccessTokenRetriever(clientId, delegate = new HttpJwtRetriever(clientId,
clientSecret, clientSecret,
scope, scope,
sslSocketFactory, sslSocketFactory,
@ -86,6 +89,21 @@ public class AccessTokenRetrieverFactory {
cu.validateInteger(SASL_LOGIN_READ_TIMEOUT_MS, false), cu.validateInteger(SASL_LOGIN_READ_TIMEOUT_MS, false),
urlencodeHeader); urlencodeHeader);
} }
delegate.init();
}
@Override
public String retrieve() throws IOException {
if (delegate == null)
throw new IllegalStateException("JWT retriever delegate is null; please call init() first");
return delegate.retrieve();
}
@Override
public void close() throws IOException {
Utils.closeQuietly(delegate, "JWT retriever delegate");
} }
/** /**
@ -96,10 +114,10 @@ public class AccessTokenRetrieverFactory {
* <p/> * <p/>
* *
* This utility method ensures that we have a non-{@code null} value to use in the * This utility method ensures that we have a non-{@code null} value to use in the
* {@link HttpAccessTokenRetriever} constructor. * {@link HttpJwtRetriever} constructor.
*/ */
static boolean validateUrlencodeHeader(ConfigurationUtils configurationUtils) { static boolean validateUrlencodeHeader(ConfigurationUtils configurationUtils) {
Boolean urlencodeHeader = configurationUtils.validateBoolean(SASL_OAUTHBEARER_HEADER_URLENCODE, false); Boolean urlencodeHeader = configurationUtils.get(SASL_OAUTHBEARER_HEADER_URLENCODE);
if (urlencodeHeader != null) if (urlencodeHeader != null)
return urlencodeHeader; return urlencodeHeader;
@ -107,4 +125,7 @@ public class AccessTokenRetrieverFactory {
return DEFAULT_SASL_OAUTHBEARER_HEADER_URLENCODE; return DEFAULT_SASL_OAUTHBEARER_HEADER_URLENCODE;
} }
JwtRetriever delegate() {
return delegate;
}
} }

View File

@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.utils.Utils;
import org.jose4j.keys.resolvers.VerificationKeyResolver;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_ISSUER;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SCOPE_CLAIM_NAME;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SUB_CLAIM_NAME;
/**
* This {@link JwtValidator} uses the delegation approach, instantiating and delegating calls to a
* more concrete implementation. The underlying implementation is determined by the presence/absence
* of the {@link VerificationKeyResolver}: if it's present, a {@link BrokerJwtValidator} is
* created, otherwise a {@link ClientJwtValidator} is created.
*/
public class DefaultJwtValidator implements JwtValidator {
private final Map<String, ?> configs;
private final String saslMechanism;
private final Optional<VerificationKeyResolver> verificationKeyResolver;
private JwtValidator delegate;
public DefaultJwtValidator(Map<String, ?> configs, String saslMechanism) {
this.configs = configs;
this.saslMechanism = saslMechanism;
this.verificationKeyResolver = Optional.empty();
}
public DefaultJwtValidator(Map<String, ?> configs,
String saslMechanism,
VerificationKeyResolver verificationKeyResolver) {
this.configs = configs;
this.saslMechanism = saslMechanism;
this.verificationKeyResolver = Optional.of(verificationKeyResolver);
}
@Override
public void init() throws IOException {
ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism);
if (verificationKeyResolver.isPresent()) {
List<String> expectedAudiencesList = cu.get(SASL_OAUTHBEARER_EXPECTED_AUDIENCE);
Set<String> expectedAudiences = expectedAudiencesList != null ? Set.copyOf(expectedAudiencesList) : null;
Integer clockSkew = cu.validateInteger(SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS, false);
String expectedIssuer = cu.validateString(SASL_OAUTHBEARER_EXPECTED_ISSUER, false);
String scopeClaimName = cu.validateString(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME);
String subClaimName = cu.validateString(SASL_OAUTHBEARER_SUB_CLAIM_NAME);
delegate = new BrokerJwtValidator(clockSkew,
expectedAudiences,
expectedIssuer,
verificationKeyResolver.get(),
scopeClaimName,
subClaimName);
} else {
String scopeClaimName = cu.get(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME);
String subClaimName = cu.get(SASL_OAUTHBEARER_SUB_CLAIM_NAME);
delegate = new ClientJwtValidator(scopeClaimName, subClaimName);
}
delegate.init();
}
@Override
public OAuthBearerToken validate(String accessToken) throws ValidateException {
if (delegate == null)
throw new IllegalStateException("JWT validator delegate is null; please call init() first");
return delegate.validate(accessToken);
}
@Override
public void close() throws IOException {
Utils.closeQuietly(delegate, "JWT validator delegate");
}
JwtValidator delegate() {
return delegate;
}
}

View File

@ -23,19 +23,19 @@ import java.io.IOException;
import java.nio.file.Path; import java.nio.file.Path;
/** /**
* <code>FileTokenRetriever</code> is an {@link AccessTokenRetriever} that will load the contents, * <code>FileJwtRetriever</code> is an {@link JwtRetriever} that will load the contents
* interpreting them as a JWT access key in the serialized form. * of a file, interpreting them as a JWT access key in the serialized form.
* *
* @see AccessTokenRetriever * @see JwtRetriever
*/ */
public class FileTokenRetriever implements AccessTokenRetriever { public class FileJwtRetriever implements JwtRetriever {
private final Path accessTokenFile; private final Path accessTokenFile;
private String accessToken; private String accessToken;
public FileTokenRetriever(Path accessTokenFile) { public FileJwtRetriever(Path accessTokenFile) {
this.accessTokenFile = accessTokenFile; this.accessTokenFile = accessTokenFile;
} }

View File

@ -49,22 +49,14 @@ import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLSocketFactory; import javax.net.ssl.SSLSocketFactory;
/** /**
* <code>HttpAccessTokenRetriever</code> is an {@link AccessTokenRetriever} that will * <code>HttpJwtRetriever</code> is a {@link JwtRetriever} that will communicate with an OAuth/OIDC
* communicate with an OAuth/OIDC provider directly via HTTP to post client credentials * provider directly via HTTP to post client credentials
* ({@link OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG}/{@link OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG}) * ({@link OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG}/{@link OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG})
* to a publicized token endpoint URL * to a publicized token endpoint URL ({@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL}).
* ({@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL}).
*
* @see AccessTokenRetriever
* @see OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG
* @see OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG
* @see OAuthBearerLoginCallbackHandler#SCOPE_CONFIG
* @see SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL
*/ */
public class HttpJwtRetriever implements JwtRetriever {
public class HttpAccessTokenRetriever implements AccessTokenRetriever { private static final Logger log = LoggerFactory.getLogger(HttpJwtRetriever.class);
private static final Logger log = LoggerFactory.getLogger(HttpAccessTokenRetriever.class);
private static final Set<Integer> UNRETRYABLE_HTTP_CODES; private static final Set<Integer> UNRETRYABLE_HTTP_CODES;
@ -117,7 +109,7 @@ public class HttpAccessTokenRetriever implements AccessTokenRetriever {
private final boolean urlencodeHeader; private final boolean urlencodeHeader;
public HttpAccessTokenRetriever(String clientId, public HttpJwtRetriever(String clientId,
String clientSecret, String clientSecret,
String scope, String scope,
SSLSocketFactory sslSocketFactory, SSLSocketFactory sslSocketFactory,

View File

@ -22,8 +22,8 @@ import java.io.IOException;
public interface Initable { public interface Initable {
/** /**
* Lifecycle method to perform any one-time initialization of the retriever. This must * Lifecycle method to perform any one-time initialization of a given resource. This must
* be performed by the caller to ensure the correct state before methods are invoked. * be invoked by the caller to ensure the correct state before methods are invoked.
* *
* @throws IOException Thrown on errors related to IO during initialization * @throws IOException Thrown on errors related to IO during initialization
*/ */
@ -31,5 +31,4 @@ public interface Initable {
default void init() throws IOException { default void init() throws IOException {
// This method left intentionally blank. // This method left intentionally blank.
} }
} }

View File

@ -21,20 +21,20 @@ import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
/** /**
* An <code>AccessTokenRetriever</code> is the internal API by which the login module will * A <code>JwtRetriever</code> is the internal API by which the login module will
* retrieve an access token for use in authorization by the broker. The implementation may * retrieve an access token for use in authorization by the broker. The implementation may
* involve authentication to a remote system, or it can be as simple as loading the contents * involve authentication to a remote system, or it can be as simple as loading the contents
* of a file or configuration setting. * of a file or configuration setting.
* *
* <i>Retrieval</i> is a separate concern from <i>validation</i>, so it isn't necessary for * <i>Retrieval</i> is a separate concern from <i>validation</i>, so it isn't necessary for
* the <code>AccessTokenRetriever</code> implementation to validate the integrity of the JWT * the <code>JwtRetriever</code> implementation to validate the integrity of the JWT
* access token. * access token.
* *
* @see HttpAccessTokenRetriever * @see HttpJwtRetriever
* @see FileTokenRetriever * @see FileJwtRetriever
*/ */
public interface AccessTokenRetriever extends Initable, Closeable { public interface JwtRetriever extends Initable, Closeable {
/** /**
* Retrieves a JWT access token in its serialized three-part form. The implementation * Retrieves a JWT access token in its serialized three-part form. The implementation

View File

@ -19,8 +19,11 @@ package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken; import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import java.io.Closeable;
import java.io.IOException;
/** /**
* An instance of <code>AccessTokenValidator</code> acts as a function object that, given an access * An instance of <code>JwtValidator</code> acts as a function object that, given an access
* token in base-64 encoded JWT format, can parse the data, perform validation, and construct an * token in base-64 encoded JWT format, can parse the data, perform validation, and construct an
* {@link OAuthBearerToken} for use by the caller. * {@link OAuthBearerToken} for use by the caller.
* *
@ -40,13 +43,12 @@ import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
* <li><a href="https://datatracker.ietf.org/doc/html/draft-ietf-oauth-access-token-jwt">RFC 6750, Section 2.1</a></li> * <li><a href="https://datatracker.ietf.org/doc/html/draft-ietf-oauth-access-token-jwt">RFC 6750, Section 2.1</a></li>
* </ul> * </ul>
* *
* @see LoginAccessTokenValidator A basic AccessTokenValidator used by client-side login * @see ClientJwtValidator A basic JwtValidator used by client-side login authentication
* authentication * @see BrokerJwtValidator A more robust JwtValidator that is used on the broker to validate the token's
* @see ValidatorAccessTokenValidator A more robust AccessTokenValidator that is used on the broker * contents and verify the signature
* to validate the token's contents and verify the signature
*/ */
public interface AccessTokenValidator { public interface JwtValidator extends Initable, Closeable {
/** /**
* Accepts an OAuth JWT access token in base-64 encoded format, validates, and returns an * Accepts an OAuth JWT access token in base-64 encoded format, validates, and returns an
@ -61,4 +63,10 @@ public interface AccessTokenValidator {
OAuthBearerToken validate(String accessToken) throws ValidateException; OAuthBearerToken validate(String accessToken) throws ValidateException;
/**
* Closes any resources that were initialized by {@link #init()}.
*/
default void close() throws IOException {
// Do nothing...
}
} }

View File

@ -49,12 +49,12 @@ import java.util.concurrent.locks.ReentrantReadWriteLock;
* This instance is created and provided to the * This instance is created and provided to the
* {@link org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver} that is used when using * {@link org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver} that is used when using
* an HTTP-/HTTPS-based {@link org.jose4j.keys.resolvers.VerificationKeyResolver}, which is then * an HTTP-/HTTPS-based {@link org.jose4j.keys.resolvers.VerificationKeyResolver}, which is then
* provided to the {@link ValidatorAccessTokenValidator} to use in validating the signature of * provided to the {@link BrokerJwtValidator} to use in validating the signature of
* a JWT. * a JWT.
* *
* @see org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver * @see org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver
* @see org.jose4j.keys.resolvers.VerificationKeyResolver * @see org.jose4j.keys.resolvers.VerificationKeyResolver
* @see ValidatorAccessTokenValidator * @see BrokerJwtValidator
*/ */
public final class RefreshingHttpsJwks implements Initable, Closeable { public final class RefreshingHttpsJwks implements Initable, Closeable {

View File

@ -27,7 +27,7 @@ import javax.security.auth.callback.Callback;
* processing of a {@link javax.security.auth.callback.CallbackHandler#handle(Callback[])}. * processing of a {@link javax.security.auth.callback.CallbackHandler#handle(Callback[])}.
* This error, however, is not thrown from that method directly. * This error, however, is not thrown from that method directly.
* *
* @see AccessTokenValidator#validate(String) * @see JwtValidator#validate(String)
*/ */
public class ValidateException extends KafkaException { public class ValidateException extends KafkaException {

View File

@ -37,7 +37,7 @@ import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_JWKS_E
public class VerificationKeyResolverFactory { public class VerificationKeyResolverFactory {
/** /**
* Create an {@link AccessTokenRetriever} from the given * Create a {@link JwtRetriever} from the given
* {@link org.apache.kafka.common.config.SaslConfigs}. * {@link org.apache.kafka.common.config.SaslConfigs}.
* *
* <b>Note</b>: the returned <code>CloseableVerificationKeyResolver</code> is not * <b>Note</b>: the returned <code>CloseableVerificationKeyResolver</code> is not

View File

@ -21,13 +21,12 @@ import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.security.auth.SaslExtensionsCallback; import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse; import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder; import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever; import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator; import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory; import org.apache.kafka.common.security.oauthbearer.internals.secured.FileJwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.FileTokenRetriever; import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.HttpAccessTokenRetriever; import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest; import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest;
import org.apache.kafka.common.utils.Utils;
import org.jose4j.jws.AlgorithmIdentifiers; import org.jose4j.jws.AlgorithmIdentifiers;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
@ -35,9 +34,7 @@ import org.junit.jupiter.api.Test;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.util.Base64;
import java.util.Calendar; import java.util.Calendar;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.TimeZone; import java.util.TimeZone;
@ -50,7 +47,6 @@ import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.ALL
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_ID_CONFIG; import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_ID_CONFIG;
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG; import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@ -58,6 +54,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest { public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@AfterEach @AfterEach
public void tearDown() throws Exception { public void tearDown() throws Exception {
System.clearProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG); System.clearProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG);
@ -70,9 +67,10 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
.jwk(createRsaJwk()) .jwk(createRsaJwk())
.alg(AlgorithmIdentifiers.RSA_USING_SHA256); .alg(AlgorithmIdentifiers.RSA_USING_SHA256);
String accessToken = builder.build(); String accessToken = builder.build();
AccessTokenRetriever accessTokenRetriever = () -> accessToken; JwtRetriever jwtRetriever = () -> accessToken;
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = createHandler(accessTokenRetriever, configs); OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(Map.of(), jwtRetriever, jwtValidator);
try { try {
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
@ -91,7 +89,6 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@Test @Test
public void testHandleSaslExtensionsCallback() throws Exception { public void testHandleSaslExtensionsCallback() throws Exception {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com"); Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com"); System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com");
Map<String, Object> jaasConfig = new HashMap<>(); Map<String, Object> jaasConfig = new HashMap<>();
@ -100,7 +97,11 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
jaasConfig.put("extension_foo", "1"); jaasConfig.put("extension_foo", "1");
jaasConfig.put("extension_bar", 2); jaasConfig.put("extension_bar", 2);
jaasConfig.put("EXTENSION_baz", "3"); jaasConfig.put("EXTENSION_baz", "3");
configureHandler(handler, configs, jaasConfig);
JwtRetriever jwtRetriever = createJwtRetriever(configs, jaasConfig);
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(jaasConfig, jwtRetriever, jwtValidator);
try { try {
SaslExtensionsCallback callback = new SaslExtensionsCallback(); SaslExtensionsCallback callback = new SaslExtensionsCallback();
@ -121,14 +122,17 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
public void testHandleSaslExtensionsCallbackWithInvalidExtension() { public void testHandleSaslExtensionsCallbackWithInvalidExtension() {
String illegalKey = "extension_" + OAuthBearerClientInitialResponse.AUTH_KEY; String illegalKey = "extension_" + OAuthBearerClientInitialResponse.AUTH_KEY;
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com"); Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com"); System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com");
Map<String, Object> jaasConfig = new HashMap<>(); Map<String, Object> jaasConfig = new HashMap<>();
jaasConfig.put(CLIENT_ID_CONFIG, "an ID"); jaasConfig.put(CLIENT_ID_CONFIG, "an ID");
jaasConfig.put(CLIENT_SECRET_CONFIG, "a secret"); jaasConfig.put(CLIENT_SECRET_CONFIG, "a secret");
jaasConfig.put(illegalKey, "this key isn't allowed per OAuthBearerClientInitialResponse.validateExtensions"); jaasConfig.put(illegalKey, "this key isn't allowed per OAuthBearerClientInitialResponse.validateExtensions");
configureHandler(handler, configs, jaasConfig);
JwtRetriever jwtRetriever = createJwtRetriever(configs, jaasConfig);
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(jaasConfig, jwtRetriever, jwtValidator);
try { try {
SaslExtensionsCallback callback = new SaslExtensionsCallback(); SaslExtensionsCallback callback = new SaslExtensionsCallback();
@ -143,10 +147,10 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@Test @Test
public void testInvalidCallbackGeneratesUnsupportedCallbackException() { public void testInvalidCallbackGeneratesUnsupportedCallbackException() {
Map<String, ?> configs = getSaslConfigs(); Map<String, ?> configs = getSaslConfigs();
JwtRetriever jwtRetriever = () -> "test";
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
AccessTokenRetriever accessTokenRetriever = () -> "foo"; handler.init(Map.of(), jwtRetriever, jwtValidator);
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs);
handler.init(accessTokenRetriever, accessTokenValidator);
try { try {
Callback unsupportedCallback = new Callback() { }; Callback unsupportedCallback = new Callback() { };
@ -166,11 +170,13 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@Test @Test
public void testMissingAccessToken() { public void testMissingAccessToken() {
AccessTokenRetriever accessTokenRetriever = () -> { Map<String, ?> configs = getSaslConfigs();
JwtRetriever jwtRetriever = () -> {
throw new IOException("The token endpoint response access_token value must be non-null"); throw new IOException("The token endpoint response access_token value must be non-null");
}; };
Map<String, ?> configs = getSaslConfigs(); JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = createHandler(accessTokenRetriever, configs); OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(Map.of(), jwtRetriever, jwtValidator);
try { try {
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
@ -196,7 +202,11 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", withNewline); File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", withNewline);
Map<String, ?> configs = getSaslConfigs(); Map<String, ?> configs = getSaslConfigs();
OAuthBearerLoginCallbackHandler handler = createHandler(new FileTokenRetriever(accessTokenFile.toPath()), configs); JwtRetriever jwtRetriever = new FileJwtRetriever(accessTokenFile.toPath());
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(Map.of(), jwtRetriever, jwtValidator);
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
try { try {
handler.handle(new Callback[]{callback}); handler.handle(new Callback[]{callback});
@ -211,39 +221,15 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@Test @Test
public void testNotConfigured() { public void testNotConfigured() {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
assertThrowsWithMessage(IllegalStateException.class, () -> handler.handle(new Callback[] {}), "first call the configure or init method"); assertThrowsWithMessage(IllegalStateException.class, () -> handler.handle(new Callback[] {}), "first call the configure method");
}
@Test
public void testConfigureWithAccessTokenFile() throws Exception {
String expected = "{}";
File tmpDir = createTempDir("access-token");
File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", expected);
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString());
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
Map<String, Object> jaasConfigs = Collections.emptyMap();
configureHandler(handler, configs, jaasConfigs);
assertInstanceOf(FileTokenRetriever.class, handler.getAccessTokenRetriever());
}
@Test
public void testConfigureWithAccessClientCredentials() {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com");
Map<String, Object> jaasConfigs = new HashMap<>();
jaasConfigs.put(CLIENT_ID_CONFIG, "an ID");
jaasConfigs.put(CLIENT_SECRET_CONFIG, "a secret");
configureHandler(handler, configs, jaasConfigs);
assertInstanceOf(HttpAccessTokenRetriever.class, handler.getAccessTokenRetriever());
} }
private void testInvalidAccessToken(String accessToken, String expectedMessageSubstring) throws Exception { private void testInvalidAccessToken(String accessToken, String expectedMessageSubstring) throws Exception {
Map<String, ?> configs = getSaslConfigs(); Map<String, ?> configs = getSaslConfigs();
OAuthBearerLoginCallbackHandler handler = createHandler(() -> accessToken, configs); JwtRetriever jwtRetriever = () -> accessToken;
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(Map.of(), jwtRetriever, jwtValidator);
try { try {
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback(); OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
@ -260,19 +246,15 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
} }
} }
private String createAccessKey(String header, String payload, String signature) { private static DefaultJwtRetriever createJwtRetriever(Map<String, ?> configs) {
Base64.Encoder enc = Base64.getEncoder(); return createJwtRetriever(configs, Map.of());
header = enc.encodeToString(Utils.utf8(header));
payload = enc.encodeToString(Utils.utf8(payload));
signature = enc.encodeToString(Utils.utf8(signature));
return String.format("%s.%s.%s", header, payload, signature);
} }
private OAuthBearerLoginCallbackHandler createHandler(AccessTokenRetriever accessTokenRetriever, Map<String, ?> configs) { private static DefaultJwtRetriever createJwtRetriever(Map<String, ?> configs, Map<String, Object> jaasConfigs) {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler(); return new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfigs);
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs);
handler.init(accessTokenRetriever, accessTokenValidator);
return handler;
} }
private static DefaultJwtValidator createJwtValidator(Map<String, ?> configs) {
return new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM);
}
} }

View File

@ -17,27 +17,30 @@
package org.apache.kafka.common.security.oauthbearer; package org.apache.kafka.common.security.oauthbearer;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder; import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver; import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest; import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest;
import org.apache.kafka.common.utils.Utils; import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException;
import org.jose4j.jws.AlgorithmIdentifiers; import org.jose4j.jws.AlgorithmIdentifiers;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.Base64;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import javax.security.auth.callback.Callback; import javax.security.auth.callback.Callback;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE; import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest { public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest {
@ -53,7 +56,10 @@ public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest {
String accessToken = builder.build(); String accessToken = builder.build();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_EXPECTED_AUDIENCE, allAudiences); Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_EXPECTED_AUDIENCE, allAudiences);
OAuthBearerValidatorCallbackHandler handler = createHandler(configs, builder); CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
JwtValidator jwtValidator = createJwtValidator(configs, verificationKeyResolver);
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
handler.init(verificationKeyResolver, jwtValidator);
try { try {
OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken); OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken);
@ -81,9 +87,68 @@ public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest {
assertInvalidAccessTokenFails(createAccessKey("{}", "{}", "{}"), substring); assertInvalidAccessTokenFails(createAccessKey("{}", "{}", "{}"), substring);
} }
@Test
public void testHandlerInitThrowsException() throws IOException {
IOException initError = new IOException("init() error");
AccessTokenBuilder builder = new AccessTokenBuilder()
.alg(AlgorithmIdentifiers.RSA_USING_SHA256);
CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
JwtValidator jwtValidator = new JwtValidator() {
@Override
public void init() throws IOException {
throw initError;
}
@Override
public OAuthBearerToken validate(String accessToken) throws ValidateException {
return null;
}
};
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
// An error initializing the JwtValidator should cause OAuthBearerValidatorCallbackHandler.init() to fail.
KafkaException root = assertThrows(
KafkaException.class,
() -> handler.init(verificationKeyResolver, jwtValidator)
);
assertNotNull(root.getCause());
assertEquals(initError, root.getCause());
}
@Test
public void testHandlerCloseDoesNotThrowException() throws IOException {
AccessTokenBuilder builder = new AccessTokenBuilder()
.alg(AlgorithmIdentifiers.RSA_USING_SHA256);
CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
JwtValidator jwtValidator = new JwtValidator() {
@Override
public void close() throws IOException {
throw new IOException("close() error");
}
@Override
public OAuthBearerToken validate(String accessToken) throws ValidateException {
return null;
}
};
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
handler.init(verificationKeyResolver, jwtValidator);
// An error closings the JwtValidator should *not* cause OAuthBearerValidatorCallbackHandler.close() to fail.
assertDoesNotThrow(handler::close);
}
private void assertInvalidAccessTokenFails(String accessToken, String expectedMessageSubstring) throws Exception { private void assertInvalidAccessTokenFails(String accessToken, String expectedMessageSubstring) throws Exception {
AccessTokenBuilder builder = new AccessTokenBuilder()
.alg(AlgorithmIdentifiers.RSA_USING_SHA256);
Map<String, ?> configs = getSaslConfigs(); Map<String, ?> configs = getSaslConfigs();
OAuthBearerValidatorCallbackHandler handler = createHandler(configs, new AccessTokenBuilder()); CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
JwtValidator jwtValidator = createJwtValidator(configs, verificationKeyResolver);
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
handler.init(verificationKeyResolver, jwtValidator);
try { try {
OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken); OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken);
@ -98,22 +163,11 @@ public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest {
} }
} }
private OAuthBearerValidatorCallbackHandler createHandler(Map<String, ?> options, private JwtValidator createJwtValidator(Map<String, ?> configs, CloseableVerificationKeyResolver verificationKeyResolver) {
AccessTokenBuilder builder) { return new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, verificationKeyResolver);
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
CloseableVerificationKeyResolver verificationKeyResolver = (jws, nestingContext) ->
builder.jwk().getPublicKey();
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(options, verificationKeyResolver);
handler.init(verificationKeyResolver, accessTokenValidator);
return handler;
} }
private String createAccessKey(String header, String payload, String signature) { private CloseableVerificationKeyResolver createVerificationKeyResolver(AccessTokenBuilder builder) {
Base64.Encoder enc = Base64.getEncoder(); return (jws, nestingContext) -> builder.jwk().getPublicKey();
header = enc.encodeToString(Utils.utf8(header));
payload = enc.encodeToString(Utils.utf8(payload));
signature = enc.encodeToString(Utils.utf8(signature));
return String.format("%s.%s.%s", header, payload, signature);
} }
} }

View File

@ -1,73 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.Map;
public class AccessTokenValidatorFactoryTest extends OAuthBearerTest {
@Test
public void testConfigureThrowsExceptionOnAccessTokenValidatorInit() {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
AccessTokenRetriever accessTokenRetriever = new AccessTokenRetriever() {
@Override
public void init() throws IOException {
throw new IOException("My init had an error!");
}
@Override
public String retrieve() {
return "dummy";
}
};
Map<String, ?> configs = getSaslConfigs();
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs);
assertThrowsWithMessage(
KafkaException.class, () -> handler.init(accessTokenRetriever, accessTokenValidator), "encountered an error when initializing");
}
@Test
public void testConfigureThrowsExceptionOnAccessTokenValidatorClose() {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
AccessTokenRetriever accessTokenRetriever = new AccessTokenRetriever() {
@Override
public void close() throws IOException {
throw new IOException("My close had an error!");
}
@Override
public String retrieve() {
return "dummy";
}
};
Map<String, ?> configs = getSaslConfigs();
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs);
handler.init(accessTokenRetriever, accessTokenValidator);
// Basically asserting this doesn't throw an exception :(
handler.close();
}
}

View File

@ -28,11 +28,11 @@ import java.util.Collections;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
public class ValidatorAccessTokenValidatorTest extends AccessTokenValidatorTest { public class BrokerJwtValidatorTest extends JwtValidatorTest {
@Override @Override
protected AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder builder) { protected JwtValidator createJwtValidator(AccessTokenBuilder builder) {
return new ValidatorAccessTokenValidator(30, return new BrokerJwtValidator(30,
Collections.emptySet(), Collections.emptySet(),
null, null,
(jws, nestingContext) -> builder.jwk().getKey(), (jws, nestingContext) -> builder.jwk().getKey(),
@ -72,7 +72,7 @@ public class ValidatorAccessTokenValidatorTest extends AccessTokenValidatorTest
.addCustomClaim(subClaimName, subject) .addCustomClaim(subClaimName, subject)
.subjectClaimName(subClaimName) .subjectClaimName(subClaimName)
.subject(null); .subject(null);
AccessTokenValidator validator = createAccessTokenValidator(tokenBuilder); JwtValidator validator = createJwtValidator(tokenBuilder);
// Validation should succeed (e.g. signature verification) even if sub claim is missing // Validation should succeed (e.g. signature verification) even if sub claim is missing
OAuthBearerToken token = validator.validate(tokenBuilder.build()); OAuthBearerToken token = validator.validate(tokenBuilder.build());
@ -82,7 +82,7 @@ public class ValidatorAccessTokenValidatorTest extends AccessTokenValidatorTest
private void testEncryptionAlgorithm(PublicJsonWebKey jwk, String alg) throws Exception { private void testEncryptionAlgorithm(PublicJsonWebKey jwk, String alg) throws Exception {
AccessTokenBuilder builder = new AccessTokenBuilder().jwk(jwk).alg(alg); AccessTokenBuilder builder = new AccessTokenBuilder().jwk(jwk).alg(alg);
AccessTokenValidator validator = createAccessTokenValidator(builder); JwtValidator validator = createJwtValidator(builder);
String accessToken = builder.build(); String accessToken = builder.build();
OAuthBearerToken token = validator.validate(accessToken); OAuthBearerToken token = validator.validate(accessToken);

View File

@ -17,11 +17,11 @@
package org.apache.kafka.common.security.oauthbearer.internals.secured; package org.apache.kafka.common.security.oauthbearer.internals.secured;
public class LoginAccessTokenValidatorTest extends AccessTokenValidatorTest { public class ClientJwtValidatorTest extends JwtValidatorTest {
@Override @Override
protected AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder builder) { protected JwtValidator createJwtValidator(AccessTokenBuilder builder) {
return new LoginAccessTokenValidator(builder.scopeClaimName(), builder.subjectClaimName()); return new ClientJwtValidator(builder.scopeClaimName(), builder.subjectClaimName());
} }
} }

View File

@ -18,6 +18,7 @@
package org.apache.kafka.common.security.oauthbearer.internals.secured; package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.config.ConfigException; import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
@ -26,7 +27,9 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource; import org.junit.jupiter.params.provider.MethodSource;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.util.Collections; import java.util.Collections;
import java.util.HashMap;
import java.util.Map; import java.util.Map;
import java.util.stream.Stream; import java.util.stream.Stream;
@ -34,9 +37,13 @@ import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARE
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_HEADER_URLENCODE; import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_HEADER_URLENCODE;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL; import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL;
import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG; import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG;
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_ID_CONFIG;
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest { public class DefaultJwtRetrieverTest extends OAuthBearerTest {
@AfterEach @AfterEach
public void tearDown() throws Exception { public void tearDown() throws Exception {
@ -44,7 +51,7 @@ public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest {
} }
@Test @Test
public void testConfigureRefreshingFileAccessTokenRetriever() throws Exception { public void testConfigureRefreshingFileJwtRetriever() throws Exception {
String expected = "{}"; String expected = "{}";
File tmpDir = createTempDir("access-token"); File tmpDir = createTempDir("access-token");
@ -54,31 +61,37 @@ public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest {
Map<String, ?> configs = Collections.singletonMap(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); Map<String, ?> configs = Collections.singletonMap(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
Map<String, Object> jaasConfig = Collections.emptyMap(); Map<String, Object> jaasConfig = Collections.emptyMap();
try (AccessTokenRetriever accessTokenRetriever = AccessTokenRetrieverFactory.create(configs, jaasConfig)) { try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfig)) {
accessTokenRetriever.init(); jwtRetriever.init();
assertEquals(expected, accessTokenRetriever.retrieve()); assertEquals(expected, jwtRetriever.retrieve());
} }
} }
@Test @Test
public void testConfigureRefreshingFileAccessTokenRetrieverWithInvalidDirectory() { public void testConfigureRefreshingFileJwtRetrieverWithInvalidDirectory() throws IOException {
// Should fail because the parent path doesn't exist. // Should fail because the parent path doesn't exist.
String file = new File("/tmp/this-directory-does-not-exist/foo.json").toURI().toString(); String file = new File("/tmp/this-directory-does-not-exist/foo.json").toURI().toString();
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, file); System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, file);
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, file); Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, file);
Map<String, Object> jaasConfig = Collections.emptyMap(); Map<String, Object> jaasConfig = Collections.emptyMap();
assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, jaasConfig), "that doesn't exist");
try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfig)) {
assertThrowsWithMessage(ConfigException.class, jwtRetriever::init, "that doesn't exist");
}
} }
@Test @Test
public void testConfigureRefreshingFileAccessTokenRetrieverWithInvalidFile() throws Exception { public void testConfigureRefreshingFileJwtRetrieverWithInvalidFile() throws Exception {
// Should fail because while the parent path exists, the file itself doesn't. // Should fail because while the parent path exists, the file itself doesn't.
File tmpDir = createTempDir("this-directory-does-exist"); File tmpDir = createTempDir("this-directory-does-exist");
File accessTokenFile = new File(tmpDir, "this-file-does-not-exist.json"); File accessTokenFile = new File(tmpDir, "this-file-does-not-exist.json");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString()); System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString());
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
Map<String, Object> jaasConfig = Collections.emptyMap(); Map<String, Object> jaasConfig = Collections.emptyMap();
assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, jaasConfig), "that doesn't exist");
try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfig)) {
assertThrowsWithMessage(ConfigException.class, jwtRetriever::init, "that doesn't exist");
}
} }
@Test @Test
@ -87,15 +100,53 @@ public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest {
File tmpDir = createTempDir("not_allowed"); File tmpDir = createTempDir("not_allowed");
File accessTokenFile = new File(tmpDir, "not_allowed.json"); File accessTokenFile = new File(tmpDir, "not_allowed.json");
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString()); Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, Collections.emptyMap()),
ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG); try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, Collections.emptyMap())) {
assertThrowsWithMessage(ConfigException.class, jwtRetriever::init, ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG);
}
}
@Test
public void testConfigureWithAccessTokenFile() throws Exception {
String expected = "{}";
File tmpDir = createTempDir("access-token");
File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", expected);
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString());
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
DefaultJwtRetriever jwtRetriever = new DefaultJwtRetriever(
configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
Map.of()
);
assertDoesNotThrow(jwtRetriever::init);
assertInstanceOf(FileJwtRetriever.class, jwtRetriever.delegate());
}
@Test
public void testConfigureWithAccessClientCredentials() {
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com");
Map<String, Object> jaasConfigs = new HashMap<>();
jaasConfigs.put(CLIENT_ID_CONFIG, "an ID");
jaasConfigs.put(CLIENT_SECRET_CONFIG, "a secret");
DefaultJwtRetriever jwtRetriever = new DefaultJwtRetriever(
configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
jaasConfigs
);
assertDoesNotThrow(jwtRetriever::init);
assertInstanceOf(HttpJwtRetriever.class, jwtRetriever.delegate());
} }
@ParameterizedTest @ParameterizedTest
@MethodSource("urlencodeHeaderSupplier") @MethodSource("urlencodeHeaderSupplier")
public void testUrlencodeHeader(Map<String, Object> configs, boolean expectedValue) { public void testUrlencodeHeader(Map<String, Object> configs, boolean expectedValue) {
ConfigurationUtils cu = new ConfigurationUtils(configs); ConfigurationUtils cu = new ConfigurationUtils(configs);
boolean actualValue = AccessTokenRetrieverFactory.validateUrlencodeHeader(cu); boolean actualValue = DefaultJwtRetriever.validateUrlencodeHeader(cu);
assertEquals(expectedValue, actualValue); assertEquals(expectedValue, actualValue);
} }

View File

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.jose4j.jws.AlgorithmIdentifiers;
import org.junit.jupiter.api.Test;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
public class DefaultJwtValidatorTest extends OAuthBearerTest {
@Test
public void testConfigureWithVerificationKeyResolver() {
AccessTokenBuilder builder = new AccessTokenBuilder()
.alg(AlgorithmIdentifiers.RSA_USING_SHA256);
CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
Map<String, ?> configs = getSaslConfigs();
DefaultJwtValidator jwtValidator = new DefaultJwtValidator(
configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
verificationKeyResolver
);
assertDoesNotThrow(jwtValidator::init);
assertInstanceOf(BrokerJwtValidator.class, jwtValidator.delegate());
}
@Test
public void testConfigureWithoutVerificationKeyResolver() {
Map<String, ?> configs = getSaslConfigs();
DefaultJwtValidator jwtValidator = new DefaultJwtValidator(
configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM
);
assertDoesNotThrow(jwtValidator::init);
assertInstanceOf(ClientJwtValidator.class, jwtValidator.delegate());
}
private CloseableVerificationKeyResolver createVerificationKeyResolver(AccessTokenBuilder builder) {
return (jws, nestingContext) -> builder.jwk().getPublicKey();
}
}

View File

@ -39,20 +39,20 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when; import static org.mockito.Mockito.when;
public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { public class HttpJwtRetrieverTest extends OAuthBearerTest {
@Test @Test
public void test() throws IOException { public void test() throws IOException {
String expectedResponse = "Hiya, buddy"; String expectedResponse = "Hiya, buddy";
HttpURLConnection mockedCon = createHttpURLConnection(expectedResponse); HttpURLConnection mockedCon = createHttpURLConnection(expectedResponse);
String response = HttpAccessTokenRetriever.post(mockedCon, null, null, null, null); String response = HttpJwtRetriever.post(mockedCon, null, null, null, null);
assertEquals(expectedResponse, response); assertEquals(expectedResponse, response);
} }
@Test @Test
public void testEmptyResponse() throws IOException { public void testEmptyResponse() throws IOException {
HttpURLConnection mockedCon = createHttpURLConnection(""); HttpURLConnection mockedCon = createHttpURLConnection("");
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); assertThrows(IOException.class, () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
} }
@Test @Test
@ -60,7 +60,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
HttpURLConnection mockedCon = createHttpURLConnection("dummy"); HttpURLConnection mockedCon = createHttpURLConnection("dummy");
when(mockedCon.getInputStream()).thenThrow(new IOException("Can't read")); when(mockedCon.getInputStream()).thenThrow(new IOException("Can't read"));
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); assertThrows(IOException.class, () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
} }
@Test @Test
@ -72,7 +72,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
.getBytes(StandardCharsets.UTF_8))); .getBytes(StandardCharsets.UTF_8)));
when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST); when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST);
UnretryableException ioe = assertThrows(UnretryableException.class, UnretryableException ioe = assertThrows(UnretryableException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}")); assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}"));
} }
@ -85,7 +85,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
.getBytes(StandardCharsets.UTF_8))); .getBytes(StandardCharsets.UTF_8)));
when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR); when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR);
IOException ioe = assertThrows(IOException.class, IOException ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}")); assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}"));
// error response body has different keys // error response body has different keys
@ -93,7 +93,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
"{\"errorCode\":\"some_arg\", \"errorSummary\":\"some problem with arg\"}" "{\"errorCode\":\"some_arg\", \"errorSummary\":\"some problem with arg\"}"
.getBytes(StandardCharsets.UTF_8))); .getBytes(StandardCharsets.UTF_8)));
ioe = assertThrows(IOException.class, ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}")); assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}"));
// error response is valid json but unknown keys // error response is valid json but unknown keys
@ -101,7 +101,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
"{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}" "{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}"
.getBytes(StandardCharsets.UTF_8))); .getBytes(StandardCharsets.UTF_8)));
ioe = assertThrows(IOException.class, ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}")); assertTrue(ioe.getMessage().contains("{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}"));
} }
@ -113,7 +113,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
"non json error output".getBytes(StandardCharsets.UTF_8))); "non json error output".getBytes(StandardCharsets.UTF_8)));
when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR); when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR);
IOException ioe = assertThrows(IOException.class, IOException ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{non json error output}")); assertTrue(ioe.getMessage().contains("{non json error output}"));
} }
@ -124,7 +124,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
r.nextBytes(expected); r.nextBytes(expected);
InputStream in = new ByteArrayInputStream(expected); InputStream in = new ByteArrayInputStream(expected);
ByteArrayOutputStream out = new ByteArrayOutputStream(); ByteArrayOutputStream out = new ByteArrayOutputStream();
HttpAccessTokenRetriever.copy(in, out); HttpJwtRetriever.copy(in, out);
assertArrayEquals(expected, out.toByteArray()); assertArrayEquals(expected, out.toByteArray());
} }
@ -133,7 +133,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
InputStream mockedIn = mock(InputStream.class); InputStream mockedIn = mock(InputStream.class);
OutputStream out = new ByteArrayOutputStream(); OutputStream out = new ByteArrayOutputStream();
when(mockedIn.read(any(byte[].class))).thenThrow(new IOException()); when(mockedIn.read(any(byte[].class))).thenThrow(new IOException());
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.copy(mockedIn, out)); assertThrows(IOException.class, () -> HttpJwtRetriever.copy(mockedIn, out));
} }
@Test @Test
@ -143,7 +143,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
ObjectNode node = mapper.createObjectNode(); ObjectNode node = mapper.createObjectNode();
node.put("access_token", expected); node.put("access_token", expected);
String actual = HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node)); String actual = HttpJwtRetriever.parseAccessToken(mapper.writeValueAsString(node));
assertEquals(expected, actual); assertEquals(expected, actual);
} }
@ -153,7 +153,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
ObjectNode node = mapper.createObjectNode(); ObjectNode node = mapper.createObjectNode();
node.put("access_token", ""); node.put("access_token", "");
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node))); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.parseAccessToken(mapper.writeValueAsString(node)));
} }
@Test @Test
@ -162,12 +162,12 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
ObjectNode node = mapper.createObjectNode(); ObjectNode node = mapper.createObjectNode();
node.put("sub", "jdoe"); node.put("sub", "jdoe");
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node))); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.parseAccessToken(mapper.writeValueAsString(node)));
} }
@Test @Test
public void testParseAccessTokenInvalidJson() { public void testParseAccessTokenInvalidJson() {
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.parseAccessToken("not valid JSON")); assertThrows(IOException.class, () -> HttpJwtRetriever.parseAccessToken("not valid JSON"));
} }
@Test @Test
@ -184,27 +184,27 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
} }
private void assertAuthorizationHeader(String clientId, String clientSecret, boolean urlencode, String expected) { private void assertAuthorizationHeader(String clientId, String clientSecret, boolean urlencode, String expected) {
String actual = HttpAccessTokenRetriever.formatAuthorizationHeader(clientId, clientSecret, urlencode); String actual = HttpJwtRetriever.formatAuthorizationHeader(clientId, clientSecret, urlencode);
assertEquals(expected, actual, String.format("Expected the HTTP Authorization header generated for client ID \"%s\" and client secret \"%s\" to match", clientId, clientSecret)); assertEquals(expected, actual, String.format("Expected the HTTP Authorization header generated for client ID \"%s\" and client secret \"%s\" to match", clientId, clientSecret));
} }
@Test @Test
public void testFormatAuthorizationHeaderMissingValues() { public void testFormatAuthorizationHeaderMissingValues() {
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(null, "secret", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(null, "secret", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", null, false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("id", null, false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(null, null, false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(null, null, false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("", "secret", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("", "secret", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", "", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("id", "", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("", "", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("", "", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(" ", "secret", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(" ", "secret", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", " ", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("id", " ", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(" ", " ", false)); assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(" ", " ", false));
} }
@Test @Test
public void testFormatRequestBody() { public void testFormatRequestBody() {
String expected = "grant_type=client_credentials&scope=scope"; String expected = "grant_type=client_credentials&scope=scope";
String actual = HttpAccessTokenRetriever.formatRequestBody("scope"); String actual = HttpJwtRetriever.formatRequestBody("scope");
assertEquals(expected, actual); assertEquals(expected, actual);
} }
@ -214,24 +214,24 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
String exclamationMark = "%21"; String exclamationMark = "%21";
String expected = String.format("grant_type=client_credentials&scope=earth+is+great%s", exclamationMark); String expected = String.format("grant_type=client_credentials&scope=earth+is+great%s", exclamationMark);
String actual = HttpAccessTokenRetriever.formatRequestBody("earth is great!"); String actual = HttpJwtRetriever.formatRequestBody("earth is great!");
assertEquals(expected, actual); assertEquals(expected, actual);
expected = String.format("grant_type=client_credentials&scope=what+on+earth%s%s%s%s%s", questionMark, exclamationMark, questionMark, exclamationMark, questionMark); expected = String.format("grant_type=client_credentials&scope=what+on+earth%s%s%s%s%s", questionMark, exclamationMark, questionMark, exclamationMark, questionMark);
actual = HttpAccessTokenRetriever.formatRequestBody("what on earth?!?!?"); actual = HttpJwtRetriever.formatRequestBody("what on earth?!?!?");
assertEquals(expected, actual); assertEquals(expected, actual);
} }
@Test @Test
public void testFormatRequestBodyMissingValues() { public void testFormatRequestBodyMissingValues() {
String expected = "grant_type=client_credentials"; String expected = "grant_type=client_credentials";
String actual = HttpAccessTokenRetriever.formatRequestBody(null); String actual = HttpJwtRetriever.formatRequestBody(null);
assertEquals(expected, actual); assertEquals(expected, actual);
actual = HttpAccessTokenRetriever.formatRequestBody(""); actual = HttpJwtRetriever.formatRequestBody("");
assertEquals(expected, actual); assertEquals(expected, actual);
actual = HttpAccessTokenRetriever.formatRequestBody(" "); actual = HttpJwtRetriever.formatRequestBody(" ");
assertEquals(expected, actual); assertEquals(expected, actual);
} }

View File

@ -26,42 +26,42 @@ import org.junit.jupiter.api.TestInstance.Lifecycle;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
@TestInstance(Lifecycle.PER_CLASS) @TestInstance(Lifecycle.PER_CLASS)
public abstract class AccessTokenValidatorTest extends OAuthBearerTest { public abstract class JwtValidatorTest extends OAuthBearerTest {
protected abstract AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder accessTokenBuilder) throws Exception; protected abstract JwtValidator createJwtValidator(AccessTokenBuilder accessTokenBuilder) throws Exception;
protected AccessTokenValidator createAccessTokenValidator() throws Exception { protected JwtValidator createJwtValidator() throws Exception {
AccessTokenBuilder builder = new AccessTokenBuilder(); AccessTokenBuilder builder = new AccessTokenBuilder();
return createAccessTokenValidator(builder); return createJwtValidator(builder);
} }
@Test @Test
public void testNull() throws Exception { public void testNull() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
assertThrowsWithMessage(ValidateException.class, () -> validator.validate(null), "Malformed JWT provided; expected three sections (header, payload, and signature)"); assertThrowsWithMessage(ValidateException.class, () -> validator.validate(null), "Malformed JWT provided; expected three sections (header, payload, and signature)");
} }
@Test @Test
public void testEmptyString() throws Exception { public void testEmptyString() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
assertThrowsWithMessage(ValidateException.class, () -> validator.validate(""), "Malformed JWT provided; expected three sections (header, payload, and signature)"); assertThrowsWithMessage(ValidateException.class, () -> validator.validate(""), "Malformed JWT provided; expected three sections (header, payload, and signature)");
} }
@Test @Test
public void testWhitespace() throws Exception { public void testWhitespace() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
assertThrowsWithMessage(ValidateException.class, () -> validator.validate(" "), "Malformed JWT provided; expected three sections (header, payload, and signature)"); assertThrowsWithMessage(ValidateException.class, () -> validator.validate(" "), "Malformed JWT provided; expected three sections (header, payload, and signature)");
} }
@Test @Test
public void testEmptySections() throws Exception { public void testEmptySections() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
assertThrowsWithMessage(ValidateException.class, () -> validator.validate(".."), "Malformed JWT provided; expected three sections (header, payload, and signature)"); assertThrowsWithMessage(ValidateException.class, () -> validator.validate(".."), "Malformed JWT provided; expected three sections (header, payload, and signature)");
} }
@Test @Test
public void testMissingHeader() throws Exception { public void testMissingHeader() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
String header = ""; String header = "";
String payload = createBase64JsonJwtSection(node -> { }); String payload = createBase64JsonJwtSection(node -> { });
String signature = ""; String signature = "";
@ -71,7 +71,7 @@ public abstract class AccessTokenValidatorTest extends OAuthBearerTest {
@Test @Test
public void testMissingPayload() throws Exception { public void testMissingPayload() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE)); String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE));
String payload = ""; String payload = "";
String signature = ""; String signature = "";
@ -81,7 +81,7 @@ public abstract class AccessTokenValidatorTest extends OAuthBearerTest {
@Test @Test
public void testMissingSignature() throws Exception { public void testMissingSignature() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator(); JwtValidator validator = createJwtValidator();
String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE)); String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE));
String payload = createBase64JsonJwtSection(node -> { }); String payload = createBase64JsonJwtSection(node -> { });
String signature = ""; String signature = "";

View File

@ -19,9 +19,6 @@ package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.config.AbstractConfig; import org.apache.kafka.common.config.AbstractConfig;
import org.apache.kafka.common.config.ConfigDef; import org.apache.kafka.common.config.ConfigDef;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.authenticator.TestJaasConfig;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.utils.Utils; import org.apache.kafka.common.utils.Utils;
import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectMapper;
@ -52,8 +49,6 @@ import java.util.Map;
import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutionException;
import java.util.function.Consumer; import java.util.function.Consumer;
import javax.security.auth.login.AppConfigurationEntry;
import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assertions.fail;
@ -80,18 +75,6 @@ public abstract class OAuthBearerTest {
expectedSubstring)); expectedSubstring));
} }
protected void configureHandler(AuthenticateCallbackHandler handler,
Map<String, ?> configs,
Map<String, Object> jaasConfig) {
TestJaasConfig config = new TestJaasConfig();
config.createOrUpdateEntry("KafkaClient", OAuthBearerLoginModule.class.getName(), jaasConfig);
AppConfigurationEntry kafkaClient = config.getAppConfigurationEntry("KafkaClient")[0];
handler.configure(configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
Collections.singletonList(kafkaClient));
}
protected String createBase64JsonJwtSection(Consumer<ObjectNode> c) { protected String createBase64JsonJwtSection(Consumer<ObjectNode> c) {
String json = createJsonJwtSection(c); String json = createJsonJwtSection(c);
@ -212,4 +195,11 @@ public abstract class OAuthBearerTest {
return jwk; return jwk;
} }
protected String createAccessKey(String header, String payload, String signature) {
Base64.Encoder enc = Base64.getEncoder();
header = enc.encodeToString(Utils.utf8(header));
payload = enc.encodeToString(Utils.utf8(payload));
signature = enc.encodeToString(Utils.utf8(signature));
return String.format("%s.%s.%s", header, payload, signature);
}
} }

View File

@ -24,11 +24,12 @@ import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.config.SaslConfigs; import org.apache.kafka.common.config.SaslConfigs;
import org.apache.kafka.common.config.SslConfigs; import org.apache.kafka.common.config.SslConfigs;
import org.apache.kafka.common.config.types.Password; import org.apache.kafka.common.config.types.Password;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever; import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetrieverFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver; import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.VerificationKeyResolverFactory; import org.apache.kafka.common.security.oauthbearer.internals.secured.VerificationKeyResolverFactory;
import org.apache.kafka.common.utils.Exit; import org.apache.kafka.common.utils.Exit;
@ -139,9 +140,11 @@ public class OAuthCompatibilityTool {
{ {
// Client side... // Client side...
try (AccessTokenRetriever atr = AccessTokenRetrieverFactory.create(configs, jaasConfigs)) { try (JwtRetriever atr = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfigs)) {
atr.init(); atr.init();
AccessTokenValidator atv = AccessTokenValidatorFactory.create(configs);
try (JwtValidator atv = new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM)) {
atv.init();
System.out.println("PASSED 1/5: client configuration"); System.out.println("PASSED 1/5: client configuration");
accessToken = atr.retrieve(); accessToken = atr.retrieve();
@ -151,18 +154,22 @@ public class OAuthCompatibilityTool {
System.out.println("PASSED 3/5: client JWT validation"); System.out.println("PASSED 3/5: client JWT validation");
} }
} }
}
{ {
// Broker side... // Broker side...
try (CloseableVerificationKeyResolver vkr = VerificationKeyResolverFactory.create(configs, jaasConfigs)) { try (CloseableVerificationKeyResolver vkr = VerificationKeyResolverFactory.create(configs, jaasConfigs)) {
vkr.init(); vkr.init();
AccessTokenValidator atv = AccessTokenValidatorFactory.create(configs, vkr);
try (JwtValidator atv = new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, vkr)) {
atv.init();
System.out.println("PASSED 4/5: broker configuration"); System.out.println("PASSED 4/5: broker configuration");
atv.validate(accessToken); atv.validate(accessToken);
System.out.println("PASSED 5/5: broker JWT validation"); System.out.println("PASSED 5/5: broker JWT validation");
} }
} }
}
System.out.println("SUCCESS"); System.out.println("SUCCESS");
Exit.exit(0); Exit.exit(0);