KAFKA-5863: Avoid NPE when RestClient calls expecting no-content receive content. (#13294)

Signed-off-by: Greg Harris <greg.harris@aiven.io>
Reviewers: Hector Geraldino <hgeraldino@gmail.com>, Yash Mayya <yash.mayya@gmail.com>
This commit is contained in:
Greg Harris 2024-01-18 11:41:27 -08:00 committed by GitHub
parent a989329ee3
commit 72b70288eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 145 additions and 61 deletions

View File

@ -1268,7 +1268,7 @@ public class DistributedHerder extends AbstractHerder implements Runnable {
try {
String stageDescription = "Forwarding zombie fencing request to the leader at " + workerUrl;
try (TemporaryStage stage = new TemporaryStage(stageDescription, callback, time)) {
restClient.httpRequest(fenceUrl, "PUT", null, null, null, sessionKey, requestSignatureAlgorithm);
restClient.httpRequest(fenceUrl, "PUT", null, null, sessionKey, requestSignatureAlgorithm);
}
callback.onCompletion(null, null);
} catch (Throwable t) {
@ -2224,7 +2224,7 @@ public class DistributedHerder extends AbstractHerder implements Runnable {
log.trace("Forwarding task configurations for connector {} to leader", connName);
String stageDescription = "Forwarding task configurations to the leader at " + leaderUrl;
try (TemporaryStage stage = new TemporaryStage(stageDescription, cb, time)) {
restClient.httpRequest(reconfigUrl, "POST", null, rawTaskProps, null, sessionKey, requestSignatureAlgorithm);
restClient.httpRequest(reconfigUrl, "POST", null, rawTaskProps, sessionKey, requestSignatureAlgorithm);
}
cb.onCompletion(null, null);
} catch (ConnectException e) {

View File

@ -141,9 +141,14 @@ public class HerderRequestHandler {
return completeOrForwardRequest(cb, path, method, headers, null, body, resultType, translator, forward);
}
public <T> T completeOrForwardRequest(FutureCallback<T> cb, String path, String method, HttpHeaders headers,
Object body, Boolean forward) throws Throwable {
return completeOrForwardRequest(cb, path, method, headers, body, null, new IdentityTranslator<>(), forward);
public <T> T completeOrForwardRequest(FutureCallback<T> cb, String path, String method, HttpHeaders headers, Object body,
TypeReference<T> resultType, Boolean forward) throws Throwable {
return completeOrForwardRequest(cb, path, method, headers, body, resultType, new IdentityTranslator<>(), forward);
}
public void completeOrForwardRequest(FutureCallback<Void> cb, String path, String method, HttpHeaders headers, Object body,
Boolean forward) throws Throwable {
completeOrForwardRequest(cb, path, method, headers, body, new TypeReference<Void>() { }, new IdentityTranslator<>(), forward);
}
public interface Translator<T, U> {

View File

@ -42,6 +42,7 @@ import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeoutException;
@ -67,13 +68,13 @@ public class RestClient {
/**
* Sends HTTP request to remote REST server
*
* @param url HTTP connection will be established with this url.
* @param method HTTP method ("GET", "POST", "PUT", etc.)
* @param url HTTP connection will be established with this url, non-null.
* @param method HTTP method ("GET", "POST", "PUT", etc.), non-null
* @param headers HTTP headers from REST endpoint
* @param requestBodyData Object to serialize as JSON and send in the request body.
* @param responseFormat Expected format of the response to the HTTP request.
* @param responseFormat Expected format of the response to the HTTP request, non-null.
* @param <T> The type of the deserialized response to the HTTP request.
* @return The deserialized response to the HTTP request, or null if no data is expected.
* @return The deserialized response to the HTTP request, containing null if no data is expected or returned.
*/
public <T> HttpResponse<T> httpRequest(String url, String method, HttpHeaders headers, Object requestBodyData,
TypeReference<T> responseFormat) {
@ -83,21 +84,41 @@ public class RestClient {
/**
* Sends HTTP request to remote REST server
*
* @param url HTTP connection will be established with this url.
* @param method HTTP method ("GET", "POST", "PUT", etc.)
* @param url HTTP connection will be established with this url, non-null.
* @param method HTTP method ("GET", "POST", "PUT", etc.), non-null
* @param headers HTTP headers from REST endpoint
* @param requestBodyData Object to serialize as JSON and send in the request body.
* @param responseFormat Expected format of the response to the HTTP request.
* @param sessionKey The key to sign the request with (intended for internal requests only);
* may be null if the request doesn't need to be signed
* @param requestSignatureAlgorithm The algorithm to sign the request with (intended for internal requests only);
* may be null if the request doesn't need to be signed
*/
public void httpRequest(String url, String method, HttpHeaders headers, Object requestBodyData,
SecretKey sessionKey, String requestSignatureAlgorithm) {
httpRequest(url, method, headers, requestBodyData, new TypeReference<Void>() { }, sessionKey, requestSignatureAlgorithm);
}
/**
* Sends HTTP request to remote REST server
*
* @param url HTTP connection will be established with this url, non-null.
* @param method HTTP method ("GET", "POST", "PUT", etc.), non-null
* @param headers HTTP headers from REST endpoint
* @param requestBodyData Object to serialize as JSON and send in the request body.
* @param responseFormat Expected format of the response to the HTTP request, non-null.
* @param <T> The type of the deserialized response to the HTTP request.
* @param sessionKey The key to sign the request with (intended for internal requests only);
* may be null if the request doesn't need to be signed
* @param requestSignatureAlgorithm The algorithm to sign the request with (intended for internal requests only);
* may be null if the request doesn't need to be signed
* @return The deserialized response to the HTTP request, or null if no data is expected.
* @return The deserialized response to the HTTP request, containing null if no data is expected or returned.
*/
public <T> HttpResponse<T> httpRequest(String url, String method, HttpHeaders headers, Object requestBodyData,
TypeReference<T> responseFormat,
SecretKey sessionKey, String requestSignatureAlgorithm) {
Objects.requireNonNull(url, "url must be non-null");
Objects.requireNonNull(method, "method must be non-null");
Objects.requireNonNull(responseFormat, "response format must be non-null");
// Only try to load SSL configs if we have to (see KAFKA-14816)
SslContextFactory sslContextFactory = url.startsWith("https://")
? SSLUtils.createClientSideSslContextFactory(config)

View File

@ -330,7 +330,7 @@ public class ConnectorsResource {
FutureCallback<Void> cb = new FutureCallback<>();
ConnectorTaskId taskId = new ConnectorTaskId(connector, task);
herder.restartTask(taskId, cb);
requestHandler.completeOrForwardRequest(cb, "/connectors/" + connector + "/tasks/" + task + "/restart", "POST", headers, null, forward);
requestHandler.completeOrForwardRequest(cb, "/connectors/" + connector + "/tasks/" + task + "/restart", "POST", headers, null, new TypeReference<Void>() { }, forward);
}
@DELETE
@ -341,7 +341,7 @@ public class ConnectorsResource {
final @Parameter(hidden = true) @QueryParam("forward") Boolean forward) throws Throwable {
FutureCallback<Herder.Created<ConnectorInfo>> cb = new FutureCallback<>();
herder.deleteConnectorConfig(connector, cb);
requestHandler.completeOrForwardRequest(cb, "/connectors/" + connector, "DELETE", headers, null, forward);
requestHandler.completeOrForwardRequest(cb, "/connectors/" + connector, "DELETE", headers, null, new TypeReference<Herder.Created<ConnectorInfo>>() { }, forward);
}
@GET

View File

@ -78,7 +78,6 @@ import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.junit.MockitoJUnitRunner;
import org.mockito.stubbing.Answer;
import org.mockito.stubbing.OngoingStubbing;
import javax.crypto.SecretKey;
import java.util.ArrayList;
@ -2695,15 +2694,14 @@ public class DistributedHerderTest {
expectRebalance(1, Collections.emptyList(), Collections.emptyList());
expectMemberPoll();
OngoingStubbing<RestClient.HttpResponse<Object>> expectRequest = when(restClient.httpRequest(
any(), eq("PUT"), isNull(), isNull(), isNull(), any(), any()
));
if (succeed) {
expectRequest.thenReturn(null);
} else {
expectRequest.thenThrow(new ConnectRestException(409, "Rebalance :("));
doAnswer(invocation -> {
if (!succeed) {
throw new ConnectRestException(409, "Rebalance :(");
}
return null;
}).when(restClient).httpRequest(
any(), eq("PUT"), isNull(), isNull(), any(), any()
);
ArgumentCaptor<Runnable> forwardRequest = ArgumentCaptor.forClass(Runnable.class);
@ -3291,10 +3289,10 @@ public class DistributedHerderTest {
changedTaskConfigs.add(TASK_CONFIG);
when(worker.connectorTaskConfigs(CONN1, sinkConnectorConfig)).thenReturn(changedTaskConfigs);
when(restClient.httpRequest(any(), eq("POST"), any(), any(), any(), any(), any()))
.thenThrow(new ConnectException("Request to leader to reconfigure connector tasks failed"))
.thenThrow(new ConnectException("Request to leader to reconfigure connector tasks failed"))
.thenReturn(null);
doThrow(new ConnectException("Request to leader to reconfigure connector tasks failed"))
.doThrow(new ConnectException("Request to leader to reconfigure connector tasks failed"))
.doNothing()
.when(restClient).httpRequest(any(), eq("POST"), any(), any(), any(), any());
expectAndVerifyTaskReconfigurationRetries();
}

View File

@ -62,9 +62,13 @@ import static org.mockito.Mockito.when;
public class RestClientTest {
private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();
private static final String MOCK_URL = "http://localhost:1234/api/endpoint";
private static final String TEST_METHOD = "GET";
private static final TestDTO TEST_DTO = new TestDTO("requestBodyData");
private static final TypeReference<TestDTO> TEST_TYPE = new TypeReference<TestDTO>() {
};
private static final SecretKey MOCK_SECRET_KEY = getMockSecretKey();
private static final String TEST_SIGNATURE_ALGORITHM = "HmacSHA1";
private static void assertIsInternalServerError(ConnectRestException e) {
assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), e.statusCode());
@ -78,31 +82,26 @@ public class RestClientTest {
return mockKey;
}
private static RestClient.HttpResponse<TestDTO> httpRequest(HttpClient httpClient, String requestSignatureAlgorithm, boolean https) {
private static <T> RestClient.HttpResponse<T> httpRequest(
HttpClient httpClient,
String url,
String method,
TypeReference<T> responseFormat,
String requestSignatureAlgorithm
) {
RestClient client = spy(new RestClient(null));
doReturn(httpClient).when(client).httpClient(any());
String protocol = https ? "https" : "http";
String url = protocol + "://localhost:1234/api/endpoint";
return client.httpRequest(
url,
"GET",
method,
null,
new TestDTO("requestBodyData"),
TEST_TYPE,
TEST_DTO,
responseFormat,
MOCK_SECRET_KEY,
requestSignatureAlgorithm
);
}
private static RestClient.HttpResponse<TestDTO> httpRequest(HttpClient httpClient, String requestSignatureAlgorithm) {
return httpRequest(httpClient, requestSignatureAlgorithm, false);
}
private static RestClient.HttpResponse<TestDTO> httpRequest(HttpClient httpClient) throws Exception {
String validRequestSignatureAlgorithm = "HmacSHA1";
return httpRequest(httpClient, validRequestSignatureAlgorithm);
}
@RunWith(Parameterized.class)
public static class RequestFailureParameterizedTest {
@ -136,7 +135,9 @@ public class RestClientTest {
public void testFailureDuringRequestCausesInternalServerError() throws Exception {
Request request = buildThrowingMockRequest(requestException);
when(httpClient.newRequest(anyString())).thenReturn(request);
ConnectRestException e = assertThrows(ConnectRestException.class, () -> httpRequest(httpClient));
ConnectRestException e = assertThrows(ConnectRestException.class, () -> httpRequest(
httpClient, MOCK_URL, TEST_METHOD, TEST_TYPE, TEST_SIGNATURE_ALGORITHM
));
assertIsInternalServerError(e);
assertEquals(requestException, e.getCause());
}
@ -166,15 +167,46 @@ public class RestClientTest {
when(httpClient.newRequest(anyString())).thenReturn(req);
}
@Test
public void testNullUrl() throws Exception {
int statusCode = Response.Status.OK.getStatusCode();
setupHttpClient(statusCode, toJsonString(TEST_DTO));
assertThrows(NullPointerException.class, () -> httpRequest(
httpClient, null, TEST_METHOD, TEST_TYPE, TEST_SIGNATURE_ALGORITHM
));
}
@Test
public void testNullMethod() throws Exception {
int statusCode = Response.Status.OK.getStatusCode();
setupHttpClient(statusCode, toJsonString(TEST_DTO));
assertThrows(NullPointerException.class, () -> httpRequest(
httpClient, MOCK_URL, null, TEST_TYPE, TEST_SIGNATURE_ALGORITHM
));
}
@Test
public void testNullResponseType() throws Exception {
int statusCode = Response.Status.OK.getStatusCode();
setupHttpClient(statusCode, toJsonString(TEST_DTO));
assertThrows(NullPointerException.class, () -> httpRequest(
httpClient, MOCK_URL, TEST_METHOD, null, TEST_SIGNATURE_ALGORITHM
));
}
@Test
public void testSuccess() throws Exception {
int statusCode = Response.Status.OK.getStatusCode();
TestDTO expectedResponse = new TestDTO("someContent");
setupHttpClient(statusCode, toJsonString(expectedResponse));
setupHttpClient(statusCode, toJsonString(TEST_DTO));
RestClient.HttpResponse<TestDTO> httpResp = httpRequest(httpClient);
RestClient.HttpResponse<TestDTO> httpResp = httpRequest(
httpClient, MOCK_URL, TEST_METHOD, TEST_TYPE, TEST_SIGNATURE_ALGORITHM
);
assertEquals(statusCode, httpResp.status());
assertEquals(expectedResponse, httpResp.body());
assertEquals(TEST_DTO, httpResp.body());
}
@Test
@ -182,7 +214,9 @@ public class RestClientTest {
int statusCode = Response.Status.NO_CONTENT.getStatusCode();
setupHttpClient(statusCode, null);
RestClient.HttpResponse<TestDTO> httpResp = httpRequest(httpClient);
RestClient.HttpResponse<TestDTO> httpResp = httpRequest(
httpClient, MOCK_URL, TEST_METHOD, TEST_TYPE, TEST_SIGNATURE_ALGORITHM
);
assertEquals(statusCode, httpResp.status());
assertNull(httpResp.body());
}
@ -193,19 +227,36 @@ public class RestClientTest {
ErrorMessage errorMsg = new ErrorMessage(Response.Status.GONE.getStatusCode(), "Some Error Message");
setupHttpClient(statusCode, toJsonString(errorMsg));
ConnectRestException e = assertThrows(ConnectRestException.class, () -> httpRequest(httpClient));
ConnectRestException e = assertThrows(ConnectRestException.class, () -> httpRequest(
httpClient, MOCK_URL, TEST_METHOD, TEST_TYPE, TEST_SIGNATURE_ALGORITHM
));
assertEquals(statusCode, e.statusCode());
assertEquals(errorMsg.errorCode(), e.errorCode());
assertEquals(errorMsg.message(), e.getMessage());
}
@Test
public void testNonEmptyResponseWithVoidResponseType() throws Exception {
int statusCode = Response.Status.OK.getStatusCode();
setupHttpClient(statusCode, toJsonString(TEST_DTO));
TypeReference<Void> voidResponse = new TypeReference<Void>() { };
RestClient.HttpResponse<Void> httpResp = httpRequest(
httpClient, MOCK_URL, TEST_METHOD, voidResponse, TEST_SIGNATURE_ALGORITHM
);
assertEquals(statusCode, httpResp.status());
assertNull(httpResp.body());
}
@Test
public void testUnexpectedHttpResponseCausesInternalServerError() throws Exception {
int statusCode = Response.Status.NOT_MODIFIED.getStatusCode(); // never thrown explicitly -
// should be treated as an unexpected error and translated into 500 INTERNAL_SERVER_ERROR
setupHttpClient(statusCode, null);
ConnectRestException e = assertThrows(ConnectRestException.class, () -> httpRequest(httpClient));
ConnectRestException e = assertThrows(ConnectRestException.class, () -> httpRequest(
httpClient, MOCK_URL, TEST_METHOD, TEST_TYPE, TEST_SIGNATURE_ALGORITHM
));
assertIsInternalServerError(e);
}
@ -213,7 +264,9 @@ public class RestClientTest {
public void testRuntimeExceptionCausesInternalServerError() {
when(httpClient.newRequest(anyString())).thenThrow(new RuntimeException());
ConnectRestException e = assertThrows(ConnectRestException.class, () -> httpRequest(httpClient));
ConnectRestException e = assertThrows(ConnectRestException.class, () -> httpRequest(
httpClient, MOCK_URL, TEST_METHOD, TEST_TYPE, TEST_SIGNATURE_ALGORITHM
));
assertIsInternalServerError(e);
}
@ -222,7 +275,9 @@ public class RestClientTest {
setupHttpClient(0, null);
String invalidRequestSignatureAlgorithm = "Foo";
ConnectRestException e = assertThrows(ConnectRestException.class, () -> httpRequest(httpClient, invalidRequestSignatureAlgorithm));
ConnectRestException e = assertThrows(ConnectRestException.class, () -> httpRequest(
httpClient, MOCK_URL, TEST_METHOD, TEST_TYPE, invalidRequestSignatureAlgorithm
));
assertIsInternalServerError(e);
}
@ -231,7 +286,9 @@ public class RestClientTest {
String invalidJsonString = "Invalid";
setupHttpClient(201, invalidJsonString);
ConnectRestException e = assertThrows(ConnectRestException.class, () -> httpRequest(httpClient));
ConnectRestException e = assertThrows(ConnectRestException.class, () -> httpRequest(
httpClient, MOCK_URL, TEST_METHOD, TEST_TYPE, TEST_SIGNATURE_ALGORITHM
));
assertIsInternalServerError(e);
}
@ -240,12 +297,15 @@ public class RestClientTest {
// See KAFKA-14816; we want to make sure that even if the worker is configured with invalid SSL properties,
// REST requests only fail if we try to contact a URL using HTTPS (but not HTTP)
int statusCode = Response.Status.OK.getStatusCode();
TestDTO expectedResponse = new TestDTO("someContent");
setupHttpClient(statusCode, toJsonString(expectedResponse));
setupHttpClient(statusCode, toJsonString(TEST_DTO));
String requestSignatureAlgorithm = "HmacSHA1";
assertDoesNotThrow(() -> httpRequest(httpClient, requestSignatureAlgorithm, false));
assertThrows(RuntimeException.class, () -> httpRequest(httpClient, requestSignatureAlgorithm, true));
assertDoesNotThrow(() -> httpRequest(
httpClient, MOCK_URL, TEST_METHOD, TEST_TYPE, TEST_SIGNATURE_ALGORITHM
));
String httpsUrl = "https://localhost:1234/api/endpoint";
assertThrows(RuntimeException.class, () -> httpRequest(
httpClient, httpsUrl, TEST_METHOD, TEST_TYPE, TEST_SIGNATURE_ALGORITHM
));
}
}

View File

@ -428,7 +428,7 @@ public class ConnectorsResourceTest {
expectAndCallbackNotLeaderException(cb).when(herder)
.deleteConnectorConfig(eq(CONNECTOR_NAME), cb.capture());
// Should forward request
when(restClient.httpRequest(LEADER_URL + "connectors/" + CONNECTOR_NAME + "?forward=false", "DELETE", NULL_HEADERS, null, null))
when(restClient.httpRequest(eq(LEADER_URL + "connectors/" + CONNECTOR_NAME + "?forward=false"), eq("DELETE"), isNull(), any(), any()))
.thenReturn(new RestClient.HttpResponse<>(204, new HashMap<>(), null));
connectorsResource.destroyConnector(CONNECTOR_NAME, NULL_HEADERS, FORWARD);
}