KAFKA-18847: Refactor OAuth layer to improve reusability 1/N (#19622)
CI / build (push) Waiting to run Details

Rename `AccessTokenRetriever` and `AccessTokenValidator` to
`JwtRetriever` and `JwtValidator`, respectively. Also converting the
factory pattern classes `AccessTokenRetrieverFactory` and
`AccessTokenValidatorFactory` into delegate/wrapper classes
`DefaultJwtRetriever` and `DefaultJwtValidator`, respectively.

These are all internal changes, no configuration, user APIs, RPCs, etc.
were changed.

Reviewers: Manikumar Reddy <manikumar@confluent.io>, Ken Huang
 <s7133700@gmail.com>, Lianet Magrans <lmagrans@confluent.io>

---------

Co-authored-by: Ken Huang <s7133700@gmail.com>
This commit is contained in:
Kirk True 2025-05-13 09:35:20 -07:00 committed by GitHub
parent c16c240bd1
commit c60c83aaba
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 598 additions and 476 deletions

View File

@ -24,12 +24,13 @@ import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.auth.SaslExtensions;
import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetrieverFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException;
import org.apache.kafka.common.utils.Utils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -179,55 +180,48 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand
private Map<String, Object> moduleOptions;
private AccessTokenRetriever accessTokenRetriever;
private JwtRetriever jwtRetriever;
private AccessTokenValidator accessTokenValidator;
private boolean isInitialized = false;
private JwtValidator jwtValidator;
@Override
public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
moduleOptions = JaasOptionsUtils.getOptions(saslMechanism, jaasConfigEntries);
AccessTokenRetriever accessTokenRetriever = AccessTokenRetrieverFactory.create(configs, saslMechanism, moduleOptions);
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs, saslMechanism);
init(accessTokenRetriever, accessTokenValidator);
}
public void init(AccessTokenRetriever accessTokenRetriever, AccessTokenValidator accessTokenValidator) {
this.accessTokenRetriever = accessTokenRetriever;
this.accessTokenValidator = accessTokenValidator;
try {
this.accessTokenRetriever.init();
} catch (IOException e) {
throw new KafkaException("The OAuth login configuration encountered an error when initializing the AccessTokenRetriever", e);
}
isInitialized = true;
Map<String, Object> moduleOptions = JaasOptionsUtils.getOptions(saslMechanism, jaasConfigEntries);
JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, saslMechanism, moduleOptions);
JwtValidator jwtValidator = new DefaultJwtValidator(configs, saslMechanism);
init(moduleOptions, jwtRetriever, jwtValidator);
}
/*
* Package-visible for testing.
*/
void init(Map<String, Object> moduleOptions, JwtRetriever jwtRetriever, JwtValidator jwtValidator) {
this.moduleOptions = moduleOptions;
this.jwtRetriever = jwtRetriever;
this.jwtValidator = jwtValidator;
AccessTokenRetriever getAccessTokenRetriever() {
return accessTokenRetriever;
}
try {
this.jwtRetriever.init();
} catch (IOException e) {
throw new KafkaException("The OAuth login callback encountered an error when initializing the JwtRetriever", e);
}
@Override
public void close() {
if (accessTokenRetriever != null) {
try {
this.accessTokenRetriever.close();
} catch (IOException e) {
log.warn("The OAuth login configuration encountered an error when closing the AccessTokenRetriever", e);
}
try {
this.jwtValidator.init();
} catch (IOException e) {
throw new KafkaException("The OAuth login callback encountered an error when initializing the JwtValidator", e);
}
}
@Override
public void close() {
Utils.closeQuietly(jwtRetriever, "JWT retriever");
Utils.closeQuietly(jwtValidator, "JWT validator");
}
@Override
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
checkInitialized();
checkConfigured();
for (Callback callback : callbacks) {
if (callback instanceof OAuthBearerTokenCallback) {
@ -241,11 +235,11 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand
}
private void handleTokenCallback(OAuthBearerTokenCallback callback) throws IOException {
checkInitialized();
String accessToken = accessTokenRetriever.retrieve();
checkConfigured();
String accessToken = jwtRetriever.retrieve();
try {
OAuthBearerToken token = accessTokenValidator.validate(accessToken);
OAuthBearerToken token = jwtValidator.validate(accessToken);
callback.token(token);
} catch (ValidateException e) {
log.warn(e.getMessage(), e);
@ -254,7 +248,7 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand
}
private void handleExtensionsCallback(SaslExtensionsCallback callback) {
checkInitialized();
checkConfigured();
Map<String, String> extensions = new HashMap<>();
@ -286,9 +280,9 @@ public class OAuthBearerLoginCallbackHandler implements AuthenticateCallbackHand
callback.extensions(saslExtensions);
}
private void checkInitialized() {
if (!isInitialized)
throw new IllegalStateException(String.format("To use %s, first call the configure or init method", getClass().getSimpleName()));
private void checkConfigured() {
if (moduleOptions == null || jwtRetriever == null || jwtValidator == null)
throw new IllegalStateException(String.format("To use %s, first call the configure method", getClass().getSimpleName()));
}
}

View File

@ -19,13 +19,14 @@ package org.apache.kafka.common.security.oauthbearer;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JaasOptionsUtils;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.RefreshingHttpsJwksVerificationKeyResolver;
import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException;
import org.apache.kafka.common.security.oauthbearer.internals.secured.VerificationKeyResolverFactory;
import org.apache.kafka.common.utils.Utils;
import org.jose4j.jws.JsonWebSignature;
import org.jose4j.jwx.JsonWebStructure;
@ -119,9 +120,7 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback
private CloseableVerificationKeyResolver verificationKeyResolver;
private AccessTokenValidator accessTokenValidator;
private boolean isInitialized = false;
private JwtValidator jwtValidator;
@Override
public void configure(Map<String, ?> configs, String saslMechanism, List<AppConfigurationEntry> jaasConfigEntries) {
@ -135,37 +134,39 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback
new RefCountingVerificationKeyResolver(VerificationKeyResolverFactory.create(configs, saslMechanism, moduleOptions)));
}
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs, saslMechanism, verificationKeyResolver);
init(verificationKeyResolver, accessTokenValidator);
JwtValidator jwtValidator = new DefaultJwtValidator(configs, saslMechanism, verificationKeyResolver);
init(verificationKeyResolver, jwtValidator);
}
public void init(CloseableVerificationKeyResolver verificationKeyResolver, AccessTokenValidator accessTokenValidator) {
/*
* Package-visible for testing.
*/
void init(CloseableVerificationKeyResolver verificationKeyResolver, JwtValidator jwtValidator) {
this.verificationKeyResolver = verificationKeyResolver;
this.accessTokenValidator = accessTokenValidator;
this.jwtValidator = jwtValidator;
try {
verificationKeyResolver.init();
} catch (Exception e) {
throw new KafkaException("The OAuth validator configuration encountered an error when initializing the VerificationKeyResolver", e);
throw new KafkaException("The OAuth validator callback encountered an error when initializing the VerificationKeyResolver", e);
}
isInitialized = true;
try {
jwtValidator.init();
} catch (IOException e) {
throw new KafkaException("The OAuth validator callback encountered an error when initializing the JwtValidator", e);
}
}
@Override
public void close() {
if (verificationKeyResolver != null) {
try {
verificationKeyResolver.close();
} catch (Exception e) {
log.error(e.getMessage(), e);
}
}
Utils.closeQuietly(jwtValidator, "JWT validator");
Utils.closeQuietly(verificationKeyResolver, "JWT verification key resolver");
}
@Override
public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
checkInitialized();
checkConfigured();
for (Callback callback : callbacks) {
if (callback instanceof OAuthBearerValidatorCallback) {
@ -179,12 +180,12 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback
}
private void handleValidatorCallback(OAuthBearerValidatorCallback callback) {
checkInitialized();
checkConfigured();
OAuthBearerToken token;
try {
token = accessTokenValidator.validate(callback.tokenValue());
token = jwtValidator.validate(callback.tokenValue());
callback.token(token);
} catch (ValidateException e) {
log.warn(e.getMessage(), e);
@ -193,14 +194,14 @@ public class OAuthBearerValidatorCallbackHandler implements AuthenticateCallback
}
private void handleExtensionsValidatorCallback(OAuthBearerExtensionsValidatorCallback extensionsValidatorCallback) {
checkInitialized();
checkConfigured();
extensionsValidatorCallback.inputExtensions().map().forEach((extensionName, v) -> extensionsValidatorCallback.valid(extensionName));
}
private void checkInitialized() {
if (!isInitialized)
throw new IllegalStateException(String.format("To use %s, first call the configure or init method", getClass().getSimpleName()));
private void checkConfigured() {
if (verificationKeyResolver == null || jwtValidator == null)
throw new IllegalStateException(String.format("To use %s, first call the configure method", getClass().getSimpleName()));
}
/**

View File

@ -1,73 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.jose4j.keys.resolvers.VerificationKeyResolver;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_ISSUER;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SCOPE_CLAIM_NAME;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SUB_CLAIM_NAME;
public class AccessTokenValidatorFactory {
public static AccessTokenValidator create(Map<String, ?> configs) {
return create(configs, (String) null);
}
public static AccessTokenValidator create(Map<String, ?> configs, String saslMechanism) {
ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism);
String scopeClaimName = cu.get(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME);
String subClaimName = cu.get(SASL_OAUTHBEARER_SUB_CLAIM_NAME);
return new LoginAccessTokenValidator(scopeClaimName, subClaimName);
}
public static AccessTokenValidator create(Map<String, ?> configs,
VerificationKeyResolver verificationKeyResolver) {
return create(configs, null, verificationKeyResolver);
}
public static AccessTokenValidator create(Map<String, ?> configs,
String saslMechanism,
VerificationKeyResolver verificationKeyResolver) {
ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism);
Set<String> expectedAudiences = null;
List<String> l = cu.get(SASL_OAUTHBEARER_EXPECTED_AUDIENCE);
if (l != null)
expectedAudiences = Set.copyOf(l);
Integer clockSkew = cu.validateInteger(SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS, false);
String expectedIssuer = cu.validateString(SASL_OAUTHBEARER_EXPECTED_ISSUER, false);
String scopeClaimName = cu.validateString(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME);
String subClaimName = cu.validateString(SASL_OAUTHBEARER_SUB_CLAIM_NAME);
return new ValidatorAccessTokenValidator(clockSkew,
expectedAudiences,
expectedIssuer,
verificationKeyResolver,
scopeClaimName,
subClaimName);
}
}

View File

@ -38,7 +38,7 @@ import java.util.Set;
import static org.jose4j.jwa.AlgorithmConstraints.DISALLOW_NONE;
/**
* ValidatorAccessTokenValidator is an implementation of {@link AccessTokenValidator} that is used
* {@code BrokerJwtValidator} is an implementation of {@link JwtValidator} that is used
* by the broker to perform more extensive validation of the JWT access token that is received
* from the client, but ultimately from posting the client credentials to the OAuth/OIDC provider's
* token endpoint.
@ -62,9 +62,9 @@ import static org.jose4j.jwa.AlgorithmConstraints.DISALLOW_NONE;
* </ol>
*/
public class ValidatorAccessTokenValidator implements AccessTokenValidator {
public class BrokerJwtValidator implements JwtValidator {
private static final Logger log = LoggerFactory.getLogger(ValidatorAccessTokenValidator.class);
private static final Logger log = LoggerFactory.getLogger(BrokerJwtValidator.class);
private final JwtConsumer jwtConsumer;
@ -73,7 +73,7 @@ public class ValidatorAccessTokenValidator implements AccessTokenValidator {
private final String subClaimName;
/**
* Creates a new ValidatorAccessTokenValidator that will be used by the broker for more
* Creates a new {@code BrokerJwtValidator} that will be used by the broker for more
* thorough validation of the JWT.
*
* @param clockSkew The optional value (in seconds) to allow for differences
@ -112,12 +112,12 @@ public class ValidatorAccessTokenValidator implements AccessTokenValidator {
* @see VerificationKeyResolver
*/
public ValidatorAccessTokenValidator(Integer clockSkew,
Set<String> expectedAudiences,
String expectedIssuer,
VerificationKeyResolver verificationKeyResolver,
String scopeClaimName,
String subClaimName) {
public BrokerJwtValidator(Integer clockSkew,
Set<String> expectedAudiences,
String expectedIssuer,
VerificationKeyResolver verificationKeyResolver,
String scopeClaimName,
String subClaimName) {
final JwtConsumerBuilder jwtConsumerBuilder = new JwtConsumerBuilder();
if (clockSkew != null)

View File

@ -33,7 +33,7 @@ import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARE
import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME;
/**
* LoginAccessTokenValidator is an implementation of {@link AccessTokenValidator} that is used
* {@code ClientJwtValidator} is an implementation of {@link JwtValidator} that is used
* by the client to perform some rudimentary validation of the JWT access token that is received
* as part of the response from posting the client credentials to the OAuth/OIDC provider's
* token endpoint.
@ -46,13 +46,13 @@ import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARE
* <a href="https://tools.ietf.org/html/rfc6750#section-2.1">RFC 6750 Section 2.1</a>
* </li>
* <li>Basic conversion of the token into an in-memory map</li>
* <li>Presence of scope, <code>exp</code>, subject, and <code>iat</code> claims</li>
* <li>Presence of <code>scope</code>, <code>exp</code>, <code>subject</code>, and <code>iat</code> claims</li>
* </ol>
*/
public class LoginAccessTokenValidator implements AccessTokenValidator {
public class ClientJwtValidator implements JwtValidator {
private static final Logger log = LoggerFactory.getLogger(LoginAccessTokenValidator.class);
private static final Logger log = LoggerFactory.getLogger(ClientJwtValidator.class);
public static final String EXPIRATION_CLAIM_NAME = "exp";
@ -63,14 +63,14 @@ public class LoginAccessTokenValidator implements AccessTokenValidator {
private final String subClaimName;
/**
* Creates a new LoginAccessTokenValidator that will be used by the client for lightweight
* Creates a new {@code ClientJwtValidator} that will be used by the client for lightweight
* validation of the JWT.
*
* @param scopeClaimName Name of the scope claim to use; must be non-<code>null</code>
* @param subClaimName Name of the subject claim to use; must be non-<code>null</code>
*/
public LoginAccessTokenValidator(String scopeClaimName, String subClaimName) {
public ClientJwtValidator(String scopeClaimName, String subClaimName) {
this.scopeClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SCOPE_CLAIM_NAME, scopeClaimName);
this.subClaimName = ClaimValidationUtils.validateClaimNameOverride(DEFAULT_SASL_OAUTHBEARER_SUB_CLAIM_NAME, subClaimName);
}

View File

@ -18,7 +18,9 @@
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.config.SaslConfigs;
import org.apache.kafka.common.utils.Utils;
import java.io.IOException;
import java.net.URL;
import java.util.Locale;
import java.util.Map;
@ -36,32 +38,33 @@ import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallb
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG;
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.SCOPE_CONFIG;
public class AccessTokenRetrieverFactory {
/**
* {@code DefaultJwtRetriever} instantiates and delegates {@link JwtRetriever} API calls to an embedded implementation
* based on configuration. If {@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL} is configured with a
* {@code file}-based URL, a {@link FileJwtRetriever} is created and the JWT is expected be contained in the file
* specified. Otherwise, it's assumed to be an HTTP/HTTPS-based URL, so an {@link HttpJwtRetriever} is created.
*/
public class DefaultJwtRetriever implements JwtRetriever {
/**
* Create an {@link AccessTokenRetriever} from the given SASL and JAAS configuration.
*
* <b>Note</b>: the returned <code>AccessTokenRetriever</code> is <em>not</em> initialized
* here and must be done by the caller prior to use.
*
* @param configs SASL configuration
* @param jaasConfig JAAS configuration
*
* @return Non-<code>null</code> {@link AccessTokenRetriever}
*/
private final Map<String, ?> configs;
private final String saslMechanism;
private final Map<String, Object> jaasConfig;
public static AccessTokenRetriever create(Map<String, ?> configs, Map<String, Object> jaasConfig) {
return create(configs, null, jaasConfig);
private JwtRetriever delegate;
public DefaultJwtRetriever(Map<String, ?> configs, String saslMechanism, Map<String, Object> jaasConfig) {
this.configs = configs;
this.saslMechanism = saslMechanism;
this.jaasConfig = jaasConfig;
}
public static AccessTokenRetriever create(Map<String, ?> configs,
String saslMechanism,
Map<String, Object> jaasConfig) {
@Override
public void init() throws IOException {
ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism);
URL tokenEndpointUrl = cu.validateUrl(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL);
if (tokenEndpointUrl.getProtocol().toLowerCase(Locale.ROOT).equals("file")) {
return new FileTokenRetriever(cu.validateFile(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL));
delegate = new FileJwtRetriever(cu.validateFile(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL));
} else {
JaasOptionsUtils jou = new JaasOptionsUtils(jaasConfig);
String clientId = jou.validateString(CLIENT_ID_CONFIG);
@ -75,7 +78,7 @@ public class AccessTokenRetrieverFactory {
boolean urlencodeHeader = validateUrlencodeHeader(cu);
return new HttpAccessTokenRetriever(clientId,
delegate = new HttpJwtRetriever(clientId,
clientSecret,
scope,
sslSocketFactory,
@ -86,6 +89,21 @@ public class AccessTokenRetrieverFactory {
cu.validateInteger(SASL_LOGIN_READ_TIMEOUT_MS, false),
urlencodeHeader);
}
delegate.init();
}
@Override
public String retrieve() throws IOException {
if (delegate == null)
throw new IllegalStateException("JWT retriever delegate is null; please call init() first");
return delegate.retrieve();
}
@Override
public void close() throws IOException {
Utils.closeQuietly(delegate, "JWT retriever delegate");
}
/**
@ -96,10 +114,10 @@ public class AccessTokenRetrieverFactory {
* <p/>
*
* This utility method ensures that we have a non-{@code null} value to use in the
* {@link HttpAccessTokenRetriever} constructor.
* {@link HttpJwtRetriever} constructor.
*/
static boolean validateUrlencodeHeader(ConfigurationUtils configurationUtils) {
Boolean urlencodeHeader = configurationUtils.validateBoolean(SASL_OAUTHBEARER_HEADER_URLENCODE, false);
Boolean urlencodeHeader = configurationUtils.get(SASL_OAUTHBEARER_HEADER_URLENCODE);
if (urlencodeHeader != null)
return urlencodeHeader;
@ -107,4 +125,7 @@ public class AccessTokenRetrieverFactory {
return DEFAULT_SASL_OAUTHBEARER_HEADER_URLENCODE;
}
JwtRetriever delegate() {
return delegate;
}
}

View File

@ -0,0 +1,108 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import org.apache.kafka.common.utils.Utils;
import org.jose4j.keys.resolvers.VerificationKeyResolver;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_ISSUER;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SCOPE_CLAIM_NAME;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_SUB_CLAIM_NAME;
/**
* This {@link JwtValidator} uses the delegation approach, instantiating and delegating calls to a
* more concrete implementation. The underlying implementation is determined by the presence/absence
* of the {@link VerificationKeyResolver}: if it's present, a {@link BrokerJwtValidator} is
* created, otherwise a {@link ClientJwtValidator} is created.
*/
public class DefaultJwtValidator implements JwtValidator {
private final Map<String, ?> configs;
private final String saslMechanism;
private final Optional<VerificationKeyResolver> verificationKeyResolver;
private JwtValidator delegate;
public DefaultJwtValidator(Map<String, ?> configs, String saslMechanism) {
this.configs = configs;
this.saslMechanism = saslMechanism;
this.verificationKeyResolver = Optional.empty();
}
public DefaultJwtValidator(Map<String, ?> configs,
String saslMechanism,
VerificationKeyResolver verificationKeyResolver) {
this.configs = configs;
this.saslMechanism = saslMechanism;
this.verificationKeyResolver = Optional.of(verificationKeyResolver);
}
@Override
public void init() throws IOException {
ConfigurationUtils cu = new ConfigurationUtils(configs, saslMechanism);
if (verificationKeyResolver.isPresent()) {
List<String> expectedAudiencesList = cu.get(SASL_OAUTHBEARER_EXPECTED_AUDIENCE);
Set<String> expectedAudiences = expectedAudiencesList != null ? Set.copyOf(expectedAudiencesList) : null;
Integer clockSkew = cu.validateInteger(SASL_OAUTHBEARER_CLOCK_SKEW_SECONDS, false);
String expectedIssuer = cu.validateString(SASL_OAUTHBEARER_EXPECTED_ISSUER, false);
String scopeClaimName = cu.validateString(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME);
String subClaimName = cu.validateString(SASL_OAUTHBEARER_SUB_CLAIM_NAME);
delegate = new BrokerJwtValidator(clockSkew,
expectedAudiences,
expectedIssuer,
verificationKeyResolver.get(),
scopeClaimName,
subClaimName);
} else {
String scopeClaimName = cu.get(SASL_OAUTHBEARER_SCOPE_CLAIM_NAME);
String subClaimName = cu.get(SASL_OAUTHBEARER_SUB_CLAIM_NAME);
delegate = new ClientJwtValidator(scopeClaimName, subClaimName);
}
delegate.init();
}
@Override
public OAuthBearerToken validate(String accessToken) throws ValidateException {
if (delegate == null)
throw new IllegalStateException("JWT validator delegate is null; please call init() first");
return delegate.validate(accessToken);
}
@Override
public void close() throws IOException {
Utils.closeQuietly(delegate, "JWT validator delegate");
}
JwtValidator delegate() {
return delegate;
}
}

View File

@ -23,19 +23,19 @@ import java.io.IOException;
import java.nio.file.Path;
/**
* <code>FileTokenRetriever</code> is an {@link AccessTokenRetriever} that will load the contents,
* interpreting them as a JWT access key in the serialized form.
* <code>FileJwtRetriever</code> is an {@link JwtRetriever} that will load the contents
* of a file, interpreting them as a JWT access key in the serialized form.
*
* @see AccessTokenRetriever
* @see JwtRetriever
*/
public class FileTokenRetriever implements AccessTokenRetriever {
public class FileJwtRetriever implements JwtRetriever {
private final Path accessTokenFile;
private String accessToken;
public FileTokenRetriever(Path accessTokenFile) {
public FileJwtRetriever(Path accessTokenFile) {
this.accessTokenFile = accessTokenFile;
}

View File

@ -49,22 +49,14 @@ import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLSocketFactory;
/**
* <code>HttpAccessTokenRetriever</code> is an {@link AccessTokenRetriever} that will
* communicate with an OAuth/OIDC provider directly via HTTP to post client credentials
* <code>HttpJwtRetriever</code> is a {@link JwtRetriever} that will communicate with an OAuth/OIDC
* provider directly via HTTP to post client credentials
* ({@link OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG}/{@link OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG})
* to a publicized token endpoint URL
* ({@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL}).
*
* @see AccessTokenRetriever
* @see OAuthBearerLoginCallbackHandler#CLIENT_ID_CONFIG
* @see OAuthBearerLoginCallbackHandler#CLIENT_SECRET_CONFIG
* @see OAuthBearerLoginCallbackHandler#SCOPE_CONFIG
* @see SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL
* to a publicized token endpoint URL ({@link SaslConfigs#SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL}).
*/
public class HttpJwtRetriever implements JwtRetriever {
public class HttpAccessTokenRetriever implements AccessTokenRetriever {
private static final Logger log = LoggerFactory.getLogger(HttpAccessTokenRetriever.class);
private static final Logger log = LoggerFactory.getLogger(HttpJwtRetriever.class);
private static final Set<Integer> UNRETRYABLE_HTTP_CODES;
@ -117,16 +109,16 @@ public class HttpAccessTokenRetriever implements AccessTokenRetriever {
private final boolean urlencodeHeader;
public HttpAccessTokenRetriever(String clientId,
String clientSecret,
String scope,
SSLSocketFactory sslSocketFactory,
String tokenEndpointUrl,
long loginRetryBackoffMs,
long loginRetryBackoffMaxMs,
Integer loginConnectTimeoutMs,
Integer loginReadTimeoutMs,
boolean urlencodeHeader) {
public HttpJwtRetriever(String clientId,
String clientSecret,
String scope,
SSLSocketFactory sslSocketFactory,
String tokenEndpointUrl,
long loginRetryBackoffMs,
long loginRetryBackoffMaxMs,
Integer loginConnectTimeoutMs,
Integer loginReadTimeoutMs,
boolean urlencodeHeader) {
this.clientId = Objects.requireNonNull(clientId);
this.clientSecret = Objects.requireNonNull(clientSecret);
this.scope = scope;

View File

@ -22,8 +22,8 @@ import java.io.IOException;
public interface Initable {
/**
* Lifecycle method to perform any one-time initialization of the retriever. This must
* be performed by the caller to ensure the correct state before methods are invoked.
* Lifecycle method to perform any one-time initialization of a given resource. This must
* be invoked by the caller to ensure the correct state before methods are invoked.
*
* @throws IOException Thrown on errors related to IO during initialization
*/
@ -31,5 +31,4 @@ public interface Initable {
default void init() throws IOException {
// This method left intentionally blank.
}
}

View File

@ -21,20 +21,20 @@ import java.io.Closeable;
import java.io.IOException;
/**
* An <code>AccessTokenRetriever</code> is the internal API by which the login module will
* A <code>JwtRetriever</code> is the internal API by which the login module will
* retrieve an access token for use in authorization by the broker. The implementation may
* involve authentication to a remote system, or it can be as simple as loading the contents
* of a file or configuration setting.
*
* <i>Retrieval</i> is a separate concern from <i>validation</i>, so it isn't necessary for
* the <code>AccessTokenRetriever</code> implementation to validate the integrity of the JWT
* the <code>JwtRetriever</code> implementation to validate the integrity of the JWT
* access token.
*
* @see HttpAccessTokenRetriever
* @see FileTokenRetriever
* @see HttpJwtRetriever
* @see FileJwtRetriever
*/
public interface AccessTokenRetriever extends Initable, Closeable {
public interface JwtRetriever extends Initable, Closeable {
/**
* Retrieves a JWT access token in its serialized three-part form. The implementation

View File

@ -19,8 +19,11 @@ package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
import java.io.Closeable;
import java.io.IOException;
/**
* An instance of <code>AccessTokenValidator</code> acts as a function object that, given an access
* An instance of <code>JwtValidator</code> acts as a function object that, given an access
* token in base-64 encoded JWT format, can parse the data, perform validation, and construct an
* {@link OAuthBearerToken} for use by the caller.
*
@ -40,13 +43,12 @@ import org.apache.kafka.common.security.oauthbearer.OAuthBearerToken;
* <li><a href="https://datatracker.ietf.org/doc/html/draft-ietf-oauth-access-token-jwt">RFC 6750, Section 2.1</a></li>
* </ul>
*
* @see LoginAccessTokenValidator A basic AccessTokenValidator used by client-side login
* authentication
* @see ValidatorAccessTokenValidator A more robust AccessTokenValidator that is used on the broker
* to validate the token's contents and verify the signature
* @see ClientJwtValidator A basic JwtValidator used by client-side login authentication
* @see BrokerJwtValidator A more robust JwtValidator that is used on the broker to validate the token's
* contents and verify the signature
*/
public interface AccessTokenValidator {
public interface JwtValidator extends Initable, Closeable {
/**
* Accepts an OAuth JWT access token in base-64 encoded format, validates, and returns an
@ -61,4 +63,10 @@ public interface AccessTokenValidator {
OAuthBearerToken validate(String accessToken) throws ValidateException;
/**
* Closes any resources that were initialized by {@link #init()}.
*/
default void close() throws IOException {
// Do nothing...
}
}

View File

@ -49,12 +49,12 @@ import java.util.concurrent.locks.ReentrantReadWriteLock;
* This instance is created and provided to the
* {@link org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver} that is used when using
* an HTTP-/HTTPS-based {@link org.jose4j.keys.resolvers.VerificationKeyResolver}, which is then
* provided to the {@link ValidatorAccessTokenValidator} to use in validating the signature of
* provided to the {@link BrokerJwtValidator} to use in validating the signature of
* a JWT.
*
* @see org.jose4j.keys.resolvers.HttpsJwksVerificationKeyResolver
* @see org.jose4j.keys.resolvers.VerificationKeyResolver
* @see ValidatorAccessTokenValidator
* @see BrokerJwtValidator
*/
public final class RefreshingHttpsJwks implements Initable, Closeable {

View File

@ -27,7 +27,7 @@ import javax.security.auth.callback.Callback;
* processing of a {@link javax.security.auth.callback.CallbackHandler#handle(Callback[])}.
* This error, however, is not thrown from that method directly.
*
* @see AccessTokenValidator#validate(String)
* @see JwtValidator#validate(String)
*/
public class ValidateException extends KafkaException {

View File

@ -37,7 +37,7 @@ import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_JWKS_E
public class VerificationKeyResolverFactory {
/**
* Create an {@link AccessTokenRetriever} from the given
* Create a {@link JwtRetriever} from the given
* {@link org.apache.kafka.common.config.SaslConfigs}.
*
* <b>Note</b>: the returned <code>CloseableVerificationKeyResolver</code> is not

View File

@ -21,13 +21,12 @@ import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.security.auth.SaslExtensionsCallback;
import org.apache.kafka.common.security.oauthbearer.internals.OAuthBearerClientInitialResponse;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.FileTokenRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.HttpAccessTokenRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.FileJwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest;
import org.apache.kafka.common.utils.Utils;
import org.jose4j.jws.AlgorithmIdentifiers;
import org.junit.jupiter.api.AfterEach;
@ -35,9 +34,7 @@ import org.junit.jupiter.api.Test;
import java.io.File;
import java.io.IOException;
import java.util.Base64;
import java.util.Calendar;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.TimeZone;
@ -50,7 +47,6 @@ import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.ALL
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_ID_CONFIG;
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
@ -58,6 +54,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@AfterEach
public void tearDown() throws Exception {
System.clearProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG);
@ -70,9 +67,10 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
.jwk(createRsaJwk())
.alg(AlgorithmIdentifiers.RSA_USING_SHA256);
String accessToken = builder.build();
AccessTokenRetriever accessTokenRetriever = () -> accessToken;
OAuthBearerLoginCallbackHandler handler = createHandler(accessTokenRetriever, configs);
JwtRetriever jwtRetriever = () -> accessToken;
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(Map.of(), jwtRetriever, jwtValidator);
try {
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
@ -91,7 +89,6 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@Test
public void testHandleSaslExtensionsCallback() throws Exception {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com");
Map<String, Object> jaasConfig = new HashMap<>();
@ -100,7 +97,11 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
jaasConfig.put("extension_foo", "1");
jaasConfig.put("extension_bar", 2);
jaasConfig.put("EXTENSION_baz", "3");
configureHandler(handler, configs, jaasConfig);
JwtRetriever jwtRetriever = createJwtRetriever(configs, jaasConfig);
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(jaasConfig, jwtRetriever, jwtValidator);
try {
SaslExtensionsCallback callback = new SaslExtensionsCallback();
@ -121,14 +122,17 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
public void testHandleSaslExtensionsCallbackWithInvalidExtension() {
String illegalKey = "extension_" + OAuthBearerClientInitialResponse.AUTH_KEY;
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com");
Map<String, Object> jaasConfig = new HashMap<>();
jaasConfig.put(CLIENT_ID_CONFIG, "an ID");
jaasConfig.put(CLIENT_SECRET_CONFIG, "a secret");
jaasConfig.put(illegalKey, "this key isn't allowed per OAuthBearerClientInitialResponse.validateExtensions");
configureHandler(handler, configs, jaasConfig);
JwtRetriever jwtRetriever = createJwtRetriever(configs, jaasConfig);
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(jaasConfig, jwtRetriever, jwtValidator);
try {
SaslExtensionsCallback callback = new SaslExtensionsCallback();
@ -143,10 +147,10 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@Test
public void testInvalidCallbackGeneratesUnsupportedCallbackException() {
Map<String, ?> configs = getSaslConfigs();
JwtRetriever jwtRetriever = () -> "test";
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
AccessTokenRetriever accessTokenRetriever = () -> "foo";
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs);
handler.init(accessTokenRetriever, accessTokenValidator);
handler.init(Map.of(), jwtRetriever, jwtValidator);
try {
Callback unsupportedCallback = new Callback() { };
@ -166,11 +170,13 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@Test
public void testMissingAccessToken() {
AccessTokenRetriever accessTokenRetriever = () -> {
Map<String, ?> configs = getSaslConfigs();
JwtRetriever jwtRetriever = () -> {
throw new IOException("The token endpoint response access_token value must be non-null");
};
Map<String, ?> configs = getSaslConfigs();
OAuthBearerLoginCallbackHandler handler = createHandler(accessTokenRetriever, configs);
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(Map.of(), jwtRetriever, jwtValidator);
try {
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
@ -196,7 +202,11 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", withNewline);
Map<String, ?> configs = getSaslConfigs();
OAuthBearerLoginCallbackHandler handler = createHandler(new FileTokenRetriever(accessTokenFile.toPath()), configs);
JwtRetriever jwtRetriever = new FileJwtRetriever(accessTokenFile.toPath());
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(Map.of(), jwtRetriever, jwtValidator);
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
try {
handler.handle(new Callback[]{callback});
@ -211,39 +221,15 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
@Test
public void testNotConfigured() {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
assertThrowsWithMessage(IllegalStateException.class, () -> handler.handle(new Callback[] {}), "first call the configure or init method");
}
@Test
public void testConfigureWithAccessTokenFile() throws Exception {
String expected = "{}";
File tmpDir = createTempDir("access-token");
File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", expected);
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString());
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
Map<String, Object> jaasConfigs = Collections.emptyMap();
configureHandler(handler, configs, jaasConfigs);
assertInstanceOf(FileTokenRetriever.class, handler.getAccessTokenRetriever());
}
@Test
public void testConfigureWithAccessClientCredentials() {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com");
Map<String, Object> jaasConfigs = new HashMap<>();
jaasConfigs.put(CLIENT_ID_CONFIG, "an ID");
jaasConfigs.put(CLIENT_SECRET_CONFIG, "a secret");
configureHandler(handler, configs, jaasConfigs);
assertInstanceOf(HttpAccessTokenRetriever.class, handler.getAccessTokenRetriever());
assertThrowsWithMessage(IllegalStateException.class, () -> handler.handle(new Callback[] {}), "first call the configure method");
}
private void testInvalidAccessToken(String accessToken, String expectedMessageSubstring) throws Exception {
Map<String, ?> configs = getSaslConfigs();
OAuthBearerLoginCallbackHandler handler = createHandler(() -> accessToken, configs);
JwtRetriever jwtRetriever = () -> accessToken;
JwtValidator jwtValidator = createJwtValidator(configs);
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
handler.init(Map.of(), jwtRetriever, jwtValidator);
try {
OAuthBearerTokenCallback callback = new OAuthBearerTokenCallback();
@ -260,19 +246,15 @@ public class OAuthBearerLoginCallbackHandlerTest extends OAuthBearerTest {
}
}
private String createAccessKey(String header, String payload, String signature) {
Base64.Encoder enc = Base64.getEncoder();
header = enc.encodeToString(Utils.utf8(header));
payload = enc.encodeToString(Utils.utf8(payload));
signature = enc.encodeToString(Utils.utf8(signature));
return String.format("%s.%s.%s", header, payload, signature);
private static DefaultJwtRetriever createJwtRetriever(Map<String, ?> configs) {
return createJwtRetriever(configs, Map.of());
}
private OAuthBearerLoginCallbackHandler createHandler(AccessTokenRetriever accessTokenRetriever, Map<String, ?> configs) {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs);
handler.init(accessTokenRetriever, accessTokenValidator);
return handler;
private static DefaultJwtRetriever createJwtRetriever(Map<String, ?> configs, Map<String, Object> jaasConfigs) {
return new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfigs);
}
private static DefaultJwtValidator createJwtValidator(Map<String, ?> configs) {
return new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM);
}
}

View File

@ -17,27 +17,30 @@
package org.apache.kafka.common.security.oauthbearer;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenBuilder;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.OAuthBearerTest;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.common.security.oauthbearer.internals.secured.ValidateException;
import org.jose4j.jws.AlgorithmIdentifiers;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import javax.security.auth.callback.Callback;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_EXPECTED_AUDIENCE;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest {
@ -53,7 +56,10 @@ public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest {
String accessToken = builder.build();
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_EXPECTED_AUDIENCE, allAudiences);
OAuthBearerValidatorCallbackHandler handler = createHandler(configs, builder);
CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
JwtValidator jwtValidator = createJwtValidator(configs, verificationKeyResolver);
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
handler.init(verificationKeyResolver, jwtValidator);
try {
OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken);
@ -81,9 +87,68 @@ public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest {
assertInvalidAccessTokenFails(createAccessKey("{}", "{}", "{}"), substring);
}
@Test
public void testHandlerInitThrowsException() throws IOException {
IOException initError = new IOException("init() error");
AccessTokenBuilder builder = new AccessTokenBuilder()
.alg(AlgorithmIdentifiers.RSA_USING_SHA256);
CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
JwtValidator jwtValidator = new JwtValidator() {
@Override
public void init() throws IOException {
throw initError;
}
@Override
public OAuthBearerToken validate(String accessToken) throws ValidateException {
return null;
}
};
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
// An error initializing the JwtValidator should cause OAuthBearerValidatorCallbackHandler.init() to fail.
KafkaException root = assertThrows(
KafkaException.class,
() -> handler.init(verificationKeyResolver, jwtValidator)
);
assertNotNull(root.getCause());
assertEquals(initError, root.getCause());
}
@Test
public void testHandlerCloseDoesNotThrowException() throws IOException {
AccessTokenBuilder builder = new AccessTokenBuilder()
.alg(AlgorithmIdentifiers.RSA_USING_SHA256);
CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
JwtValidator jwtValidator = new JwtValidator() {
@Override
public void close() throws IOException {
throw new IOException("close() error");
}
@Override
public OAuthBearerToken validate(String accessToken) throws ValidateException {
return null;
}
};
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
handler.init(verificationKeyResolver, jwtValidator);
// An error closings the JwtValidator should *not* cause OAuthBearerValidatorCallbackHandler.close() to fail.
assertDoesNotThrow(handler::close);
}
private void assertInvalidAccessTokenFails(String accessToken, String expectedMessageSubstring) throws Exception {
AccessTokenBuilder builder = new AccessTokenBuilder()
.alg(AlgorithmIdentifiers.RSA_USING_SHA256);
Map<String, ?> configs = getSaslConfigs();
OAuthBearerValidatorCallbackHandler handler = createHandler(configs, new AccessTokenBuilder());
CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
JwtValidator jwtValidator = createJwtValidator(configs, verificationKeyResolver);
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
handler.init(verificationKeyResolver, jwtValidator);
try {
OAuthBearerValidatorCallback callback = new OAuthBearerValidatorCallback(accessToken);
@ -98,22 +163,11 @@ public class OAuthBearerValidatorCallbackHandlerTest extends OAuthBearerTest {
}
}
private OAuthBearerValidatorCallbackHandler createHandler(Map<String, ?> options,
AccessTokenBuilder builder) {
OAuthBearerValidatorCallbackHandler handler = new OAuthBearerValidatorCallbackHandler();
CloseableVerificationKeyResolver verificationKeyResolver = (jws, nestingContext) ->
builder.jwk().getPublicKey();
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(options, verificationKeyResolver);
handler.init(verificationKeyResolver, accessTokenValidator);
return handler;
private JwtValidator createJwtValidator(Map<String, ?> configs, CloseableVerificationKeyResolver verificationKeyResolver) {
return new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, verificationKeyResolver);
}
private String createAccessKey(String header, String payload, String signature) {
Base64.Encoder enc = Base64.getEncoder();
header = enc.encodeToString(Utils.utf8(header));
payload = enc.encodeToString(Utils.utf8(payload));
signature = enc.encodeToString(Utils.utf8(signature));
return String.format("%s.%s.%s", header, payload, signature);
private CloseableVerificationKeyResolver createVerificationKeyResolver(AccessTokenBuilder builder) {
return (jws, nestingContext) -> builder.jwk().getPublicKey();
}
}

View File

@ -1,73 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler;
import org.junit.jupiter.api.Test;
import java.io.IOException;
import java.util.Map;
public class AccessTokenValidatorFactoryTest extends OAuthBearerTest {
@Test
public void testConfigureThrowsExceptionOnAccessTokenValidatorInit() {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
AccessTokenRetriever accessTokenRetriever = new AccessTokenRetriever() {
@Override
public void init() throws IOException {
throw new IOException("My init had an error!");
}
@Override
public String retrieve() {
return "dummy";
}
};
Map<String, ?> configs = getSaslConfigs();
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs);
assertThrowsWithMessage(
KafkaException.class, () -> handler.init(accessTokenRetriever, accessTokenValidator), "encountered an error when initializing");
}
@Test
public void testConfigureThrowsExceptionOnAccessTokenValidatorClose() {
OAuthBearerLoginCallbackHandler handler = new OAuthBearerLoginCallbackHandler();
AccessTokenRetriever accessTokenRetriever = new AccessTokenRetriever() {
@Override
public void close() throws IOException {
throw new IOException("My close had an error!");
}
@Override
public String retrieve() {
return "dummy";
}
};
Map<String, ?> configs = getSaslConfigs();
AccessTokenValidator accessTokenValidator = AccessTokenValidatorFactory.create(configs);
handler.init(accessTokenRetriever, accessTokenValidator);
// Basically asserting this doesn't throw an exception :(
handler.close();
}
}

View File

@ -28,11 +28,11 @@ import java.util.Collections;
import static org.junit.jupiter.api.Assertions.assertEquals;
public class ValidatorAccessTokenValidatorTest extends AccessTokenValidatorTest {
public class BrokerJwtValidatorTest extends JwtValidatorTest {
@Override
protected AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder builder) {
return new ValidatorAccessTokenValidator(30,
protected JwtValidator createJwtValidator(AccessTokenBuilder builder) {
return new BrokerJwtValidator(30,
Collections.emptySet(),
null,
(jws, nestingContext) -> builder.jwk().getKey(),
@ -72,7 +72,7 @@ public class ValidatorAccessTokenValidatorTest extends AccessTokenValidatorTest
.addCustomClaim(subClaimName, subject)
.subjectClaimName(subClaimName)
.subject(null);
AccessTokenValidator validator = createAccessTokenValidator(tokenBuilder);
JwtValidator validator = createJwtValidator(tokenBuilder);
// Validation should succeed (e.g. signature verification) even if sub claim is missing
OAuthBearerToken token = validator.validate(tokenBuilder.build());
@ -82,7 +82,7 @@ public class ValidatorAccessTokenValidatorTest extends AccessTokenValidatorTest
private void testEncryptionAlgorithm(PublicJsonWebKey jwk, String alg) throws Exception {
AccessTokenBuilder builder = new AccessTokenBuilder().jwk(jwk).alg(alg);
AccessTokenValidator validator = createAccessTokenValidator(builder);
JwtValidator validator = createJwtValidator(builder);
String accessToken = builder.build();
OAuthBearerToken token = validator.validate(accessToken);

View File

@ -17,11 +17,11 @@
package org.apache.kafka.common.security.oauthbearer.internals.secured;
public class LoginAccessTokenValidatorTest extends AccessTokenValidatorTest {
public class ClientJwtValidatorTest extends JwtValidatorTest {
@Override
protected AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder builder) {
return new LoginAccessTokenValidator(builder.scopeClaimName(), builder.subjectClaimName());
protected JwtValidator createJwtValidator(AccessTokenBuilder builder) {
return new ClientJwtValidator(builder.scopeClaimName(), builder.subjectClaimName());
}
}

View File

@ -18,6 +18,7 @@
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
@ -26,7 +27,9 @@ import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import java.io.File;
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Stream;
@ -34,9 +37,13 @@ import static org.apache.kafka.common.config.SaslConfigs.DEFAULT_SASL_OAUTHBEARE
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_HEADER_URLENCODE;
import static org.apache.kafka.common.config.SaslConfigs.SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL;
import static org.apache.kafka.common.config.internals.BrokerSecurityConfigs.ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG;
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_ID_CONFIG;
import static org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginCallbackHandler.CLIENT_SECRET_CONFIG;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest {
public class DefaultJwtRetrieverTest extends OAuthBearerTest {
@AfterEach
public void tearDown() throws Exception {
@ -44,7 +51,7 @@ public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest {
}
@Test
public void testConfigureRefreshingFileAccessTokenRetriever() throws Exception {
public void testConfigureRefreshingFileJwtRetriever() throws Exception {
String expected = "{}";
File tmpDir = createTempDir("access-token");
@ -54,31 +61,37 @@ public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest {
Map<String, ?> configs = Collections.singletonMap(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
Map<String, Object> jaasConfig = Collections.emptyMap();
try (AccessTokenRetriever accessTokenRetriever = AccessTokenRetrieverFactory.create(configs, jaasConfig)) {
accessTokenRetriever.init();
assertEquals(expected, accessTokenRetriever.retrieve());
try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfig)) {
jwtRetriever.init();
assertEquals(expected, jwtRetriever.retrieve());
}
}
@Test
public void testConfigureRefreshingFileAccessTokenRetrieverWithInvalidDirectory() {
public void testConfigureRefreshingFileJwtRetrieverWithInvalidDirectory() throws IOException {
// Should fail because the parent path doesn't exist.
String file = new File("/tmp/this-directory-does-not-exist/foo.json").toURI().toString();
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, file);
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, file);
Map<String, Object> jaasConfig = Collections.emptyMap();
assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, jaasConfig), "that doesn't exist");
try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfig)) {
assertThrowsWithMessage(ConfigException.class, jwtRetriever::init, "that doesn't exist");
}
}
@Test
public void testConfigureRefreshingFileAccessTokenRetrieverWithInvalidFile() throws Exception {
public void testConfigureRefreshingFileJwtRetrieverWithInvalidFile() throws Exception {
// Should fail because while the parent path exists, the file itself doesn't.
File tmpDir = createTempDir("this-directory-does-exist");
File accessTokenFile = new File(tmpDir, "this-file-does-not-exist.json");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString());
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
Map<String, Object> jaasConfig = Collections.emptyMap();
assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, jaasConfig), "that doesn't exist");
try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfig)) {
assertThrowsWithMessage(ConfigException.class, jwtRetriever::init, "that doesn't exist");
}
}
@Test
@ -87,15 +100,53 @@ public class AccessTokenRetrieverFactoryTest extends OAuthBearerTest {
File tmpDir = createTempDir("not_allowed");
File accessTokenFile = new File(tmpDir, "not_allowed.json");
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
assertThrowsWithMessage(ConfigException.class, () -> AccessTokenRetrieverFactory.create(configs, Collections.emptyMap()),
ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG);
try (JwtRetriever jwtRetriever = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, Collections.emptyMap())) {
assertThrowsWithMessage(ConfigException.class, jwtRetriever::init, ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG);
}
}
@Test
public void testConfigureWithAccessTokenFile() throws Exception {
String expected = "{}";
File tmpDir = createTempDir("access-token");
File accessTokenFile = createTempFile(tmpDir, "access-token-", ".json", expected);
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, accessTokenFile.toURI().toString());
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, accessTokenFile.toURI().toString());
DefaultJwtRetriever jwtRetriever = new DefaultJwtRetriever(
configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
Map.of()
);
assertDoesNotThrow(jwtRetriever::init);
assertInstanceOf(FileJwtRetriever.class, jwtRetriever.delegate());
}
@Test
public void testConfigureWithAccessClientCredentials() {
Map<String, ?> configs = getSaslConfigs(SASL_OAUTHBEARER_TOKEN_ENDPOINT_URL, "http://www.example.com");
System.setProperty(ALLOWED_SASL_OAUTHBEARER_URLS_CONFIG, "http://www.example.com");
Map<String, Object> jaasConfigs = new HashMap<>();
jaasConfigs.put(CLIENT_ID_CONFIG, "an ID");
jaasConfigs.put(CLIENT_SECRET_CONFIG, "a secret");
DefaultJwtRetriever jwtRetriever = new DefaultJwtRetriever(
configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
jaasConfigs
);
assertDoesNotThrow(jwtRetriever::init);
assertInstanceOf(HttpJwtRetriever.class, jwtRetriever.delegate());
}
@ParameterizedTest
@MethodSource("urlencodeHeaderSupplier")
public void testUrlencodeHeader(Map<String, Object> configs, boolean expectedValue) {
ConfigurationUtils cu = new ConfigurationUtils(configs);
boolean actualValue = AccessTokenRetrieverFactory.validateUrlencodeHeader(cu);
boolean actualValue = DefaultJwtRetriever.validateUrlencodeHeader(cu);
assertEquals(expectedValue, actualValue);
}

View File

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.jose4j.jws.AlgorithmIdentifiers;
import org.junit.jupiter.api.Test;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.assertDoesNotThrow;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
public class DefaultJwtValidatorTest extends OAuthBearerTest {
@Test
public void testConfigureWithVerificationKeyResolver() {
AccessTokenBuilder builder = new AccessTokenBuilder()
.alg(AlgorithmIdentifiers.RSA_USING_SHA256);
CloseableVerificationKeyResolver verificationKeyResolver = createVerificationKeyResolver(builder);
Map<String, ?> configs = getSaslConfigs();
DefaultJwtValidator jwtValidator = new DefaultJwtValidator(
configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
verificationKeyResolver
);
assertDoesNotThrow(jwtValidator::init);
assertInstanceOf(BrokerJwtValidator.class, jwtValidator.delegate());
}
@Test
public void testConfigureWithoutVerificationKeyResolver() {
Map<String, ?> configs = getSaslConfigs();
DefaultJwtValidator jwtValidator = new DefaultJwtValidator(
configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM
);
assertDoesNotThrow(jwtValidator::init);
assertInstanceOf(ClientJwtValidator.class, jwtValidator.delegate());
}
private CloseableVerificationKeyResolver createVerificationKeyResolver(AccessTokenBuilder builder) {
return (jws, nestingContext) -> builder.jwk().getPublicKey();
}
}

View File

@ -39,20 +39,20 @@ import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
public class HttpJwtRetrieverTest extends OAuthBearerTest {
@Test
public void test() throws IOException {
String expectedResponse = "Hiya, buddy";
HttpURLConnection mockedCon = createHttpURLConnection(expectedResponse);
String response = HttpAccessTokenRetriever.post(mockedCon, null, null, null, null);
String response = HttpJwtRetriever.post(mockedCon, null, null, null, null);
assertEquals(expectedResponse, response);
}
@Test
public void testEmptyResponse() throws IOException {
HttpURLConnection mockedCon = createHttpURLConnection("");
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null));
assertThrows(IOException.class, () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
}
@Test
@ -60,7 +60,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
HttpURLConnection mockedCon = createHttpURLConnection("dummy");
when(mockedCon.getInputStream()).thenThrow(new IOException("Can't read"));
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null));
assertThrows(IOException.class, () -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
}
@Test
@ -72,7 +72,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
.getBytes(StandardCharsets.UTF_8)));
when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST);
UnretryableException ioe = assertThrows(UnretryableException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null));
() -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}"));
}
@ -85,7 +85,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
.getBytes(StandardCharsets.UTF_8)));
when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR);
IOException ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null));
() -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}"));
// error response body has different keys
@ -93,7 +93,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
"{\"errorCode\":\"some_arg\", \"errorSummary\":\"some problem with arg\"}"
.getBytes(StandardCharsets.UTF_8)));
ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null));
() -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}"));
// error response is valid json but unknown keys
@ -101,7 +101,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
"{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}"
.getBytes(StandardCharsets.UTF_8)));
ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null));
() -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}"));
}
@ -113,7 +113,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
"non json error output".getBytes(StandardCharsets.UTF_8)));
when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR);
IOException ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null));
() -> HttpJwtRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{non json error output}"));
}
@ -124,7 +124,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
r.nextBytes(expected);
InputStream in = new ByteArrayInputStream(expected);
ByteArrayOutputStream out = new ByteArrayOutputStream();
HttpAccessTokenRetriever.copy(in, out);
HttpJwtRetriever.copy(in, out);
assertArrayEquals(expected, out.toByteArray());
}
@ -133,7 +133,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
InputStream mockedIn = mock(InputStream.class);
OutputStream out = new ByteArrayOutputStream();
when(mockedIn.read(any(byte[].class))).thenThrow(new IOException());
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.copy(mockedIn, out));
assertThrows(IOException.class, () -> HttpJwtRetriever.copy(mockedIn, out));
}
@Test
@ -143,7 +143,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
ObjectNode node = mapper.createObjectNode();
node.put("access_token", expected);
String actual = HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node));
String actual = HttpJwtRetriever.parseAccessToken(mapper.writeValueAsString(node));
assertEquals(expected, actual);
}
@ -153,7 +153,7 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
ObjectNode node = mapper.createObjectNode();
node.put("access_token", "");
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node)));
assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.parseAccessToken(mapper.writeValueAsString(node)));
}
@Test
@ -162,12 +162,12 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
ObjectNode node = mapper.createObjectNode();
node.put("sub", "jdoe");
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.parseAccessToken(mapper.writeValueAsString(node)));
assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.parseAccessToken(mapper.writeValueAsString(node)));
}
@Test
public void testParseAccessTokenInvalidJson() {
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.parseAccessToken("not valid JSON"));
assertThrows(IOException.class, () -> HttpJwtRetriever.parseAccessToken("not valid JSON"));
}
@Test
@ -184,27 +184,27 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
}
private void assertAuthorizationHeader(String clientId, String clientSecret, boolean urlencode, String expected) {
String actual = HttpAccessTokenRetriever.formatAuthorizationHeader(clientId, clientSecret, urlencode);
String actual = HttpJwtRetriever.formatAuthorizationHeader(clientId, clientSecret, urlencode);
assertEquals(expected, actual, String.format("Expected the HTTP Authorization header generated for client ID \"%s\" and client secret \"%s\" to match", clientId, clientSecret));
}
@Test
public void testFormatAuthorizationHeaderMissingValues() {
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(null, "secret", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", null, false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(null, null, false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("", "secret", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", "", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("", "", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(" ", "secret", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader("id", " ", false));
assertThrows(IllegalArgumentException.class, () -> HttpAccessTokenRetriever.formatAuthorizationHeader(" ", " ", false));
assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(null, "secret", false));
assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("id", null, false));
assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(null, null, false));
assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("", "secret", false));
assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("id", "", false));
assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("", "", false));
assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(" ", "secret", false));
assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader("id", " ", false));
assertThrows(IllegalArgumentException.class, () -> HttpJwtRetriever.formatAuthorizationHeader(" ", " ", false));
}
@Test
public void testFormatRequestBody() {
String expected = "grant_type=client_credentials&scope=scope";
String actual = HttpAccessTokenRetriever.formatRequestBody("scope");
String actual = HttpJwtRetriever.formatRequestBody("scope");
assertEquals(expected, actual);
}
@ -214,24 +214,24 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
String exclamationMark = "%21";
String expected = String.format("grant_type=client_credentials&scope=earth+is+great%s", exclamationMark);
String actual = HttpAccessTokenRetriever.formatRequestBody("earth is great!");
String actual = HttpJwtRetriever.formatRequestBody("earth is great!");
assertEquals(expected, actual);
expected = String.format("grant_type=client_credentials&scope=what+on+earth%s%s%s%s%s", questionMark, exclamationMark, questionMark, exclamationMark, questionMark);
actual = HttpAccessTokenRetriever.formatRequestBody("what on earth?!?!?");
actual = HttpJwtRetriever.formatRequestBody("what on earth?!?!?");
assertEquals(expected, actual);
}
@Test
public void testFormatRequestBodyMissingValues() {
String expected = "grant_type=client_credentials";
String actual = HttpAccessTokenRetriever.formatRequestBody(null);
String actual = HttpJwtRetriever.formatRequestBody(null);
assertEquals(expected, actual);
actual = HttpAccessTokenRetriever.formatRequestBody("");
actual = HttpJwtRetriever.formatRequestBody("");
assertEquals(expected, actual);
actual = HttpAccessTokenRetriever.formatRequestBody(" ");
actual = HttpJwtRetriever.formatRequestBody(" ");
assertEquals(expected, actual);
}

View File

@ -26,42 +26,42 @@ import org.junit.jupiter.api.TestInstance.Lifecycle;
import static org.junit.jupiter.api.Assertions.assertThrows;
@TestInstance(Lifecycle.PER_CLASS)
public abstract class AccessTokenValidatorTest extends OAuthBearerTest {
public abstract class JwtValidatorTest extends OAuthBearerTest {
protected abstract AccessTokenValidator createAccessTokenValidator(AccessTokenBuilder accessTokenBuilder) throws Exception;
protected abstract JwtValidator createJwtValidator(AccessTokenBuilder accessTokenBuilder) throws Exception;
protected AccessTokenValidator createAccessTokenValidator() throws Exception {
protected JwtValidator createJwtValidator() throws Exception {
AccessTokenBuilder builder = new AccessTokenBuilder();
return createAccessTokenValidator(builder);
return createJwtValidator(builder);
}
@Test
public void testNull() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator();
JwtValidator validator = createJwtValidator();
assertThrowsWithMessage(ValidateException.class, () -> validator.validate(null), "Malformed JWT provided; expected three sections (header, payload, and signature)");
}
@Test
public void testEmptyString() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator();
JwtValidator validator = createJwtValidator();
assertThrowsWithMessage(ValidateException.class, () -> validator.validate(""), "Malformed JWT provided; expected three sections (header, payload, and signature)");
}
@Test
public void testWhitespace() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator();
JwtValidator validator = createJwtValidator();
assertThrowsWithMessage(ValidateException.class, () -> validator.validate(" "), "Malformed JWT provided; expected three sections (header, payload, and signature)");
}
@Test
public void testEmptySections() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator();
JwtValidator validator = createJwtValidator();
assertThrowsWithMessage(ValidateException.class, () -> validator.validate(".."), "Malformed JWT provided; expected three sections (header, payload, and signature)");
}
@Test
public void testMissingHeader() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator();
JwtValidator validator = createJwtValidator();
String header = "";
String payload = createBase64JsonJwtSection(node -> { });
String signature = "";
@ -71,7 +71,7 @@ public abstract class AccessTokenValidatorTest extends OAuthBearerTest {
@Test
public void testMissingPayload() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator();
JwtValidator validator = createJwtValidator();
String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE));
String payload = "";
String signature = "";
@ -81,7 +81,7 @@ public abstract class AccessTokenValidatorTest extends OAuthBearerTest {
@Test
public void testMissingSignature() throws Exception {
AccessTokenValidator validator = createAccessTokenValidator();
JwtValidator validator = createJwtValidator();
String header = createBase64JsonJwtSection(node -> node.put(HeaderParameterNames.ALGORITHM, AlgorithmIdentifiers.NONE));
String payload = createBase64JsonJwtSection(node -> { });
String signature = "";

View File

@ -19,9 +19,6 @@ package org.apache.kafka.common.security.oauthbearer.internals.secured;
import org.apache.kafka.common.config.AbstractConfig;
import org.apache.kafka.common.config.ConfigDef;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.apache.kafka.common.security.authenticator.TestJaasConfig;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.utils.Utils;
import com.fasterxml.jackson.databind.ObjectMapper;
@ -52,8 +49,6 @@ import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.function.Consumer;
import javax.security.auth.login.AppConfigurationEntry;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
@ -80,18 +75,6 @@ public abstract class OAuthBearerTest {
expectedSubstring));
}
protected void configureHandler(AuthenticateCallbackHandler handler,
Map<String, ?> configs,
Map<String, Object> jaasConfig) {
TestJaasConfig config = new TestJaasConfig();
config.createOrUpdateEntry("KafkaClient", OAuthBearerLoginModule.class.getName(), jaasConfig);
AppConfigurationEntry kafkaClient = config.getAppConfigurationEntry("KafkaClient")[0];
handler.configure(configs,
OAuthBearerLoginModule.OAUTHBEARER_MECHANISM,
Collections.singletonList(kafkaClient));
}
protected String createBase64JsonJwtSection(Consumer<ObjectNode> c) {
String json = createJsonJwtSection(c);
@ -212,4 +195,11 @@ public abstract class OAuthBearerTest {
return jwk;
}
protected String createAccessKey(String header, String payload, String signature) {
Base64.Encoder enc = Base64.getEncoder();
header = enc.encodeToString(Utils.utf8(header));
payload = enc.encodeToString(Utils.utf8(payload));
signature = enc.encodeToString(Utils.utf8(signature));
return String.format("%s.%s.%s", header, payload, signature);
}
}

View File

@ -24,11 +24,12 @@ import org.apache.kafka.common.config.ConfigException;
import org.apache.kafka.common.config.SaslConfigs;
import org.apache.kafka.common.config.SslConfigs;
import org.apache.kafka.common.config.types.Password;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenRetrieverFactory;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.AccessTokenValidatorFactory;
import org.apache.kafka.common.security.oauthbearer.OAuthBearerLoginModule;
import org.apache.kafka.common.security.oauthbearer.internals.secured.CloseableVerificationKeyResolver;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.DefaultJwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtRetriever;
import org.apache.kafka.common.security.oauthbearer.internals.secured.JwtValidator;
import org.apache.kafka.common.security.oauthbearer.internals.secured.VerificationKeyResolverFactory;
import org.apache.kafka.common.utils.Exit;
@ -139,16 +140,19 @@ public class OAuthCompatibilityTool {
{
// Client side...
try (AccessTokenRetriever atr = AccessTokenRetrieverFactory.create(configs, jaasConfigs)) {
try (JwtRetriever atr = new DefaultJwtRetriever(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, jaasConfigs)) {
atr.init();
AccessTokenValidator atv = AccessTokenValidatorFactory.create(configs);
System.out.println("PASSED 1/5: client configuration");
accessToken = atr.retrieve();
System.out.println("PASSED 2/5: client JWT retrieval");
try (JwtValidator atv = new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM)) {
atv.init();
System.out.println("PASSED 1/5: client configuration");
atv.validate(accessToken);
System.out.println("PASSED 3/5: client JWT validation");
accessToken = atr.retrieve();
System.out.println("PASSED 2/5: client JWT retrieval");
atv.validate(accessToken);
System.out.println("PASSED 3/5: client JWT validation");
}
}
}
@ -156,11 +160,14 @@ public class OAuthCompatibilityTool {
// Broker side...
try (CloseableVerificationKeyResolver vkr = VerificationKeyResolverFactory.create(configs, jaasConfigs)) {
vkr.init();
AccessTokenValidator atv = AccessTokenValidatorFactory.create(configs, vkr);
System.out.println("PASSED 4/5: broker configuration");
atv.validate(accessToken);
System.out.println("PASSED 5/5: broker JWT validation");
try (JwtValidator atv = new DefaultJwtValidator(configs, OAuthBearerLoginModule.OAUTHBEARER_MECHANISM, vkr)) {
atv.init();
System.out.println("PASSED 4/5: broker configuration");
atv.validate(accessToken);
System.out.println("PASSED 5/5: broker JWT validation");
}
}
}