KAFKA-14212: Enhance HttpAccessTokenRetriever to retrieve error message (#12651)

Currently HttpAccessTokenRetriever client side class does not retrieve error response from the token e/p. As a result, seemingly trivial config issues could take a lot of time to diagnose and fix. For example, client could be sending invalid client secret, id or scope.
This PR aims to remedy the situation by retrieving the error response, if present and logging as well as appending to any exceptions thrown.
New unit tests have also been added.

Reviewers: Manikumar Reddy <manikumar.reddy@gmail.com>
This commit is contained in:
Sushant Mahajan 2022-09-20 12:33:19 +05:30 committed by GitHub
parent 8c8b5366a6
commit f8e0a6d924
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 96 additions and 5 deletions

View File

@ -240,6 +240,7 @@ public class HttpAccessTokenRetriever implements AccessTokenRetriever {
log.debug("handleOutput - responseCode: {}", responseCode);
String responseBody = null;
String errorResponseBody = null;
try (InputStream is = con.getInputStream()) {
ByteArrayOutputStream os = new ByteArrayOutputStream();
@ -247,27 +248,41 @@ public class HttpAccessTokenRetriever implements AccessTokenRetriever {
copy(is, os);
responseBody = os.toString(StandardCharsets.UTF_8.name());
} catch (Exception e) {
// there still can be useful error response from the servers, lets get it
try (InputStream is = con.getErrorStream()) {
ByteArrayOutputStream os = new ByteArrayOutputStream();
log.debug("handleOutput - preparing to read error response body from {}", con.getURL());
copy(is, os);
errorResponseBody = os.toString(StandardCharsets.UTF_8.name());
} catch (Exception e2) {
log.warn("handleOutput - error retrieving error information", e2);
}
log.warn("handleOutput - error retrieving data", e);
}
if (responseCode == HttpURLConnection.HTTP_OK || responseCode == HttpURLConnection.HTTP_CREATED) {
log.debug("handleOutput - responseCode: {}, response: {}", responseCode, responseBody);
log.debug("handleOutput - responseCode: {}, response: {}, error response: {}", responseCode, responseBody,
errorResponseBody);
if (responseBody == null || responseBody.isEmpty())
throw new IOException(String.format("The token endpoint response was unexpectedly empty despite response code %s from %s", responseCode, con.getURL()));
throw new IOException(String.format("The token endpoint response was unexpectedly empty despite response code %s from %s and error message %s",
responseCode, con.getURL(), formatErrorMessage(errorResponseBody)));
return responseBody;
} else {
log.warn("handleOutput - error response code: {}, error response body: {}", responseCode, responseBody);
log.warn("handleOutput - error response code: {}, response body: {}, error response body: {}", responseCode,
responseBody, errorResponseBody);
if (UNRETRYABLE_HTTP_CODES.contains(responseCode)) {
// We know that this is a non-transient error, so let's not keep retrying the
// request unnecessarily.
throw new UnretryableException(new IOException(String.format("The response code %s was encountered reading the token endpoint response; will not attempt further retries", responseCode)));
throw new UnretryableException(new IOException(String.format("The response code %s and error response %s was encountered reading the token endpoint response; will not attempt further retries",
responseCode, formatErrorMessage(errorResponseBody))));
} else {
// We don't know if this is a transient (retryable) error or not, so let's assume
// it is.
throw new IOException(String.format("The unexpected response code %s was encountered reading the token endpoint response", responseCode));
throw new IOException(String.format("The unexpected response code %s and error message %s was encountered reading the token endpoint response",
responseCode, formatErrorMessage(errorResponseBody)));
}
}
}
@ -280,6 +295,26 @@ public class HttpAccessTokenRetriever implements AccessTokenRetriever {
os.write(buf, 0, b);
}
static String formatErrorMessage(String errorResponseBody) {
if (errorResponseBody == null || errorResponseBody.trim().equals("")) {
return "{}";
}
ObjectMapper mapper = new ObjectMapper();
try {
JsonNode rootNode = mapper.readTree(errorResponseBody);
if (!rootNode.at("/error").isMissingNode()) {
return String.format("{%s - %s}", rootNode.at("/error"), rootNode.at("/error_description"));
} else if (!rootNode.at("/errorCode").isMissingNode()) {
return String.format("{%s - %s}", rootNode.at("/errorCode"), rootNode.at("/errorSummary"));
} else {
return errorResponseBody;
}
} catch (Exception e) {
log.warn("Error parsing error response", e);
}
return String.format("{%s}", errorResponseBody);
}
static String parseAccessToken(String responseBody) throws IOException {
log.debug("parseAccessToken - responseBody: {}", responseBody);
ObjectMapper mapper = new ObjectMapper();

View File

@ -20,6 +20,7 @@ package org.apache.kafka.common.security.oauthbearer.secured;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
@ -32,6 +33,7 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.nio.charset.StandardCharsets;
import java.util.Base64;
import java.util.Random;
import org.apache.kafka.common.utils.Utils;
@ -61,6 +63,60 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest {
assertThrows(IOException.class, () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null));
}
@Test
public void testErrorResponseUnretryableCode() throws IOException {
HttpURLConnection mockedCon = createHttpURLConnection("dummy");
when(mockedCon.getInputStream()).thenThrow(new IOException("Can't read"));
when(mockedCon.getErrorStream()).thenReturn(new ByteArrayInputStream(
"{\"error\":\"some_arg\", \"error_description\":\"some problem with arg\"}"
.getBytes(StandardCharsets.UTF_8)));
when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST);
UnretryableException ioe = assertThrows(UnretryableException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}"));
}
@Test
public void testErrorResponseRetryableCode() throws IOException {
HttpURLConnection mockedCon = createHttpURLConnection("dummy");
when(mockedCon.getInputStream()).thenThrow(new IOException("Can't read"));
when(mockedCon.getErrorStream()).thenReturn(new ByteArrayInputStream(
"{\"error\":\"some_arg\", \"error_description\":\"some problem with arg\"}"
.getBytes(StandardCharsets.UTF_8)));
when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR);
IOException ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}"));
// error response body has different keys
when(mockedCon.getErrorStream()).thenReturn(new ByteArrayInputStream(
"{\"errorCode\":\"some_arg\", \"errorSummary\":\"some problem with arg\"}"
.getBytes(StandardCharsets.UTF_8)));
ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}"));
// error response is valid json but unknown keys
when(mockedCon.getErrorStream()).thenReturn(new ByteArrayInputStream(
"{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}"
.getBytes(StandardCharsets.UTF_8)));
ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}"));
}
@Test
public void testErrorResponseIsInvalidJson() throws IOException {
HttpURLConnection mockedCon = createHttpURLConnection("dummy");
when(mockedCon.getInputStream()).thenThrow(new IOException("Can't read"));
when(mockedCon.getErrorStream()).thenReturn(new ByteArrayInputStream(
"non json error output".getBytes(StandardCharsets.UTF_8)));
when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR);
IOException ioe = assertThrows(IOException.class,
() -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null));
assertTrue(ioe.getMessage().contains("{non json error output}"));
}
@Test
public void testCopy() throws IOException {
byte[] expected = new byte[4096 + 1];