diff --git a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetriever.java b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetriever.java index b52952a16f1..b92a6c3ea1f 100644 --- a/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetriever.java +++ b/clients/src/main/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetriever.java @@ -240,6 +240,7 @@ public class HttpAccessTokenRetriever implements AccessTokenRetriever { log.debug("handleOutput - responseCode: {}", responseCode); String responseBody = null; + String errorResponseBody = null; try (InputStream is = con.getInputStream()) { ByteArrayOutputStream os = new ByteArrayOutputStream(); @@ -247,27 +248,41 @@ public class HttpAccessTokenRetriever implements AccessTokenRetriever { copy(is, os); responseBody = os.toString(StandardCharsets.UTF_8.name()); } catch (Exception e) { + // there still can be useful error response from the servers, lets get it + try (InputStream is = con.getErrorStream()) { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + log.debug("handleOutput - preparing to read error response body from {}", con.getURL()); + copy(is, os); + errorResponseBody = os.toString(StandardCharsets.UTF_8.name()); + } catch (Exception e2) { + log.warn("handleOutput - error retrieving error information", e2); + } log.warn("handleOutput - error retrieving data", e); } if (responseCode == HttpURLConnection.HTTP_OK || responseCode == HttpURLConnection.HTTP_CREATED) { - log.debug("handleOutput - responseCode: {}, response: {}", responseCode, responseBody); + log.debug("handleOutput - responseCode: {}, response: {}, error response: {}", responseCode, responseBody, + errorResponseBody); if (responseBody == null || responseBody.isEmpty()) - throw new IOException(String.format("The token endpoint response was unexpectedly empty despite response code %s from %s", responseCode, con.getURL())); + throw new IOException(String.format("The token endpoint response was unexpectedly empty despite response code %s from %s and error message %s", + responseCode, con.getURL(), formatErrorMessage(errorResponseBody))); return responseBody; } else { - log.warn("handleOutput - error response code: {}, error response body: {}", responseCode, responseBody); + log.warn("handleOutput - error response code: {}, response body: {}, error response body: {}", responseCode, + responseBody, errorResponseBody); if (UNRETRYABLE_HTTP_CODES.contains(responseCode)) { // We know that this is a non-transient error, so let's not keep retrying the // request unnecessarily. - throw new UnretryableException(new IOException(String.format("The response code %s was encountered reading the token endpoint response; will not attempt further retries", responseCode))); + throw new UnretryableException(new IOException(String.format("The response code %s and error response %s was encountered reading the token endpoint response; will not attempt further retries", + responseCode, formatErrorMessage(errorResponseBody)))); } else { // We don't know if this is a transient (retryable) error or not, so let's assume // it is. - throw new IOException(String.format("The unexpected response code %s was encountered reading the token endpoint response", responseCode)); + throw new IOException(String.format("The unexpected response code %s and error message %s was encountered reading the token endpoint response", + responseCode, formatErrorMessage(errorResponseBody))); } } } @@ -280,6 +295,26 @@ public class HttpAccessTokenRetriever implements AccessTokenRetriever { os.write(buf, 0, b); } + static String formatErrorMessage(String errorResponseBody) { + if (errorResponseBody == null || errorResponseBody.trim().equals("")) { + return "{}"; + } + ObjectMapper mapper = new ObjectMapper(); + try { + JsonNode rootNode = mapper.readTree(errorResponseBody); + if (!rootNode.at("/error").isMissingNode()) { + return String.format("{%s - %s}", rootNode.at("/error"), rootNode.at("/error_description")); + } else if (!rootNode.at("/errorCode").isMissingNode()) { + return String.format("{%s - %s}", rootNode.at("/errorCode"), rootNode.at("/errorSummary")); + } else { + return errorResponseBody; + } + } catch (Exception e) { + log.warn("Error parsing error response", e); + } + return String.format("{%s}", errorResponseBody); + } + static String parseAccessToken(String responseBody) throws IOException { log.debug("parseAccessToken - responseBody: {}", responseBody); ObjectMapper mapper = new ObjectMapper(); diff --git a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetrieverTest.java b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetrieverTest.java index de3b4634470..66252ffd52f 100644 --- a/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetrieverTest.java +++ b/clients/src/test/java/org/apache/kafka/common/security/oauthbearer/secured/HttpAccessTokenRetrieverTest.java @@ -20,6 +20,7 @@ package org.apache.kafka.common.security.oauthbearer.secured; import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -32,6 +33,7 @@ import java.io.IOException; import java.io.InputStream; import java.io.OutputStream; import java.net.HttpURLConnection; +import java.nio.charset.StandardCharsets; import java.util.Base64; import java.util.Random; import org.apache.kafka.common.utils.Utils; @@ -61,6 +63,60 @@ public class HttpAccessTokenRetrieverTest extends OAuthBearerTest { assertThrows(IOException.class, () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); } + @Test + public void testErrorResponseUnretryableCode() throws IOException { + HttpURLConnection mockedCon = createHttpURLConnection("dummy"); + when(mockedCon.getInputStream()).thenThrow(new IOException("Can't read")); + when(mockedCon.getErrorStream()).thenReturn(new ByteArrayInputStream( + "{\"error\":\"some_arg\", \"error_description\":\"some problem with arg\"}" + .getBytes(StandardCharsets.UTF_8))); + when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_BAD_REQUEST); + UnretryableException ioe = assertThrows(UnretryableException.class, + () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}")); + } + + @Test + public void testErrorResponseRetryableCode() throws IOException { + HttpURLConnection mockedCon = createHttpURLConnection("dummy"); + when(mockedCon.getInputStream()).thenThrow(new IOException("Can't read")); + when(mockedCon.getErrorStream()).thenReturn(new ByteArrayInputStream( + "{\"error\":\"some_arg\", \"error_description\":\"some problem with arg\"}" + .getBytes(StandardCharsets.UTF_8))); + when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR); + IOException ioe = assertThrows(IOException.class, + () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}")); + + // error response body has different keys + when(mockedCon.getErrorStream()).thenReturn(new ByteArrayInputStream( + "{\"errorCode\":\"some_arg\", \"errorSummary\":\"some problem with arg\"}" + .getBytes(StandardCharsets.UTF_8))); + ioe = assertThrows(IOException.class, + () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + assertTrue(ioe.getMessage().contains("{\"some_arg\" - \"some problem with arg\"}")); + + // error response is valid json but unknown keys + when(mockedCon.getErrorStream()).thenReturn(new ByteArrayInputStream( + "{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}" + .getBytes(StandardCharsets.UTF_8))); + ioe = assertThrows(IOException.class, + () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + assertTrue(ioe.getMessage().contains("{\"err\":\"some_arg\", \"err_des\":\"some problem with arg\"}")); + } + + @Test + public void testErrorResponseIsInvalidJson() throws IOException { + HttpURLConnection mockedCon = createHttpURLConnection("dummy"); + when(mockedCon.getInputStream()).thenThrow(new IOException("Can't read")); + when(mockedCon.getErrorStream()).thenReturn(new ByteArrayInputStream( + "non json error output".getBytes(StandardCharsets.UTF_8))); + when(mockedCon.getResponseCode()).thenReturn(HttpURLConnection.HTTP_INTERNAL_ERROR); + IOException ioe = assertThrows(IOException.class, + () -> HttpAccessTokenRetriever.post(mockedCon, null, null, null, null)); + assertTrue(ioe.getMessage().contains("{non json error output}")); + } + @Test public void testCopy() throws IOException { byte[] expected = new byte[4096 + 1];