Add Ibm Granite Completion and Chat Completion support (#129146)

* Add Ibm Granite Completion and Chat Completion support

* Apply suggestions

* remove ibm watsonx transport version constant

* update transport version
This commit is contained in:
Evgenii-Kazannik 2025-07-02 22:57:16 +02:00 committed by GitHub
parent 82b6e45a81
commit 5d0c5e02bd
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
36 changed files with 1395 additions and 189 deletions

View File

@ -0,0 +1,5 @@
pr: 129146
summary: "[ML] Add IBM watsonx Completion and Chat Completion support to the Inference Plugin"
area: Machine Learning
type: enhancement
issues: []

View File

@ -328,7 +328,7 @@ public class TransportVersions {
public static final TransportVersion MAPPINGS_IN_DATA_STREAMS = def(9_112_0_00);
public static final TransportVersion PROJECT_STATE_REGISTRY_RECORDS_DELETIONS = def(9_113_0_00);
public static final TransportVersion ESQL_SERIALIZE_TIMESERIES_FIELD_TYPE = def(9_114_0_00);
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_115_0_00);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

View File

@ -151,7 +151,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
"completion_test_service",
"hugging_face",
"amazon_sagemaker",
"mistral"
"mistral",
"watsonxai"
).toArray()
)
);
@ -169,7 +170,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
"hugging_face",
"amazon_sagemaker",
"googlevertexai",
"mistral"
"mistral",
"watsonxai"
).toArray()
)
);

View File

@ -95,6 +95,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.completion.Hugging
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
@ -469,6 +470,13 @@ public class InferenceNamedWriteablesProvider {
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, IbmWatsonxRerankTaskSettings.NAME, IbmWatsonxRerankTaskSettings::new)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
IbmWatsonxChatCompletionServiceSettings.NAME,
IbmWatsonxChatCompletionServiceSettings::new
)
);
}
private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {

View File

@ -80,8 +80,8 @@ public class CohereRerankModel extends CohereModel {
/**
* Accepts a visitor to create an executable action. The returned action will not return documents in the response.
* @param visitor _
* @param taskSettings _
* @param visitor Interface for creating {@link ExecutableAction} instances for Cohere models.
* @param taskSettings Settings in the request to override the model's defaults
* @return the rerank action
*/
@Override

View File

@ -0,0 +1,51 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
import java.util.Locale;
/**
 * Response handler for streaming (unified) chat completion requests sent to IBM watsonx.
 * <p>
 * Reuses the OpenAI unified chat completion handling and plugs in {@link IbmWatsonxErrorResponseEntity#fromResponse}
 * as the error parser so that watsonx error payloads are understood.
 */
public class IbmWatsonUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {

    /** Error type reported when the failure payload was parsed as an {@link IbmWatsonxErrorResponseEntity}. */
    private static final String WATSONX_ERROR = "watsonx_error";

    /**
     * @param requestType   description of the request kind, forwarded to the parent handler
     * @param parseFunction parser for successful responses, forwarded to the parent handler
     */
    public IbmWatsonUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
        super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
    }

    /**
     * Builds the exception describing a failed request.
     * <p>
     * Streaming requests produce a {@link UnifiedChatCompletionException}; its error type is {@code watsonx_error} when the
     * parsed error is an {@link IbmWatsonxErrorResponseEntity}, otherwise whatever {@code createErrorType} yields.
     * Non-streaming requests (only reachable with assertions disabled) are delegated to the parent implementation.
     */
    @Override
    protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
        assert request.isStreaming() : "Only streaming requests support this format";
        var statusCode = result.response().getStatusLine().getStatusCode();
        if (request.isStreaming() == false) {
            return super.buildError(message, request, result, errorResponse);
        }

        var detail = errorMessage(message, request, result, errorResponse, statusCode);
        var status = toRestStatus(statusCode);
        // Both outcomes share everything except the error type, so resolve that first.
        String errorType = errorResponse instanceof IbmWatsonxErrorResponseEntity ? WATSONX_ERROR : createErrorType(errorResponse);
        return new UnifiedChatCompletionException(status, detail, errorType, status.name().toLowerCase(Locale.ROOT));
    }
}

View File

@ -0,0 +1,25 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
/**
* Response handler for IBM watsonx completion requests.
* <p>
* Extends {@link OpenAiChatCompletionResponseHandler} and passes {@link IbmWatsonxErrorResponseEntity#fromResponse}
* to the parent constructor as the error parser; all other behavior is inherited.
*/
public class IbmWatsonxCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
/**
* Constructs an IbmWatsonxCompletionResponseHandler with the specified request type and response parser.
*
* @param requestType The type of request being handled (e.g., "IBM watsonx completions").
* @param parseFunction The function to parse the response.
*/
public IbmWatsonxCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
}
}

View File

@ -35,7 +35,7 @@ public class IbmWatsonxEmbeddingsRequestManager extends IbmWatsonxRequestManager
private static final ResponseHandler HANDLER = createEmbeddingsHandler();
private static ResponseHandler createEmbeddingsHandler() {
return new IbmWatsonxResponseHandler("ibm watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse);
return new IbmWatsonxResponseHandler("IBM watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse);
}
private final IbmWatsonxEmbeddingsModel model;

View File

@ -7,18 +7,19 @@
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
import org.elasticsearch.inference.Model;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.inference.TaskSettings;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.util.Map;
import java.util.Objects;
public abstract class IbmWatsonxModel extends Model {
public abstract class IbmWatsonxModel extends RateLimitGroupingModel {
private final IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings;
@ -49,4 +50,14 @@ public abstract class IbmWatsonxModel extends Model {
/**
* Returns the {@code rateLimitServiceSettings} held by this model.
*
* @return the rate-limit service settings
*/
public IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings() {
return rateLimitServiceSettings;
}
/**
* Computes the rate-limit grouping hash from this model's {@code rateLimitServiceSettings}.
* NOTE(review): presumably models with equal hashes share a rate-limit bucket; confirm against RateLimitGroupingModel.
*
* @return hash derived solely from the rate-limit service settings
*/
@Override
public int rateLimitGroupingHash() {
return Objects.hash(this.rateLimitServiceSettings);
}
/**
* Returns the rate limit settings exposed by {@link #rateLimitServiceSettings()}.
*
* @return the rate limit settings of this model's service settings
*/
@Override
public RateLimitSettings rateLimitSettings() {
// Goes through the accessor rather than the field, so an overriding accessor is honored.
return this.rateLimitServiceSettings().rateLimitSettings();
}
}

View File

@ -31,7 +31,7 @@ public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager {
private static ResponseHandler createIbmWatsonxResponseHandler() {
return new IbmWatsonxResponseHandler(
"ibm watsonx rerank",
"IBM watsonx rerank",
(request, response) -> IbmWatsonxRankedResponseEntity.fromResponse(response)
);
}

View File

@ -30,7 +30,10 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
@ -40,14 +43,18 @@ import org.elasticsearch.xpack.inference.services.SenderService;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ServiceUtils;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
@ -56,7 +63,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersi
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE;
@ -66,8 +72,16 @@ public class IbmWatsonxService extends SenderService {
public static final String NAME = "watsonxai";
private static final String SERVICE_NAME = "IBM Watsonx";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING);
private static final String SERVICE_NAME = "IBM watsonx";
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
TaskType.TEXT_EMBEDDING,
TaskType.COMPLETION,
TaskType.CHAT_COMPLETION
);
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new IbmWatsonUnifiedChatCompletionResponseHandler(
"IBM watsonx chat completions",
OpenAiChatCompletionResponseEntity::fromResponse
);
public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
super(factory, serviceComponents);
@ -148,6 +162,14 @@ public class IbmWatsonxService extends SenderService {
secretSettings,
context
);
case CHAT_COMPLETION, COMPLETION -> new IbmWatsonxChatCompletionModel(
inferenceEntityId,
taskType,
NAME,
serviceSettings,
secretSettings,
context
);
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
};
}
@ -236,6 +258,11 @@ public class IbmWatsonxService extends SenderService {
return TransportVersions.V_8_16_0;
}
/**
* Returns the task types for which this service reports streaming support:
* {@link TaskType#COMPLETION} and {@link TaskType#CHAT_COMPLETION}.
*
* @return a newly created set containing the two streaming-capable task types
*/
@Override
public Set<TaskType> supportedStreamingTasks() {
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
}
@Override
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
if (model instanceof IbmWatsonxEmbeddingsModel embeddingsModel) {
@ -291,7 +318,24 @@ public class IbmWatsonxService extends SenderService {
TimeValue timeout,
ActionListener<InferenceServiceResults> listener
) {
throwUnsupportedUnifiedCompletionOperation(NAME);
if (model instanceof IbmWatsonxChatCompletionModel == false) {
listener.onFailure(createInvalidModelException(model));
return;
}
IbmWatsonxChatCompletionModel ibmWatsonxChatCompletionModel = (IbmWatsonxChatCompletionModel) model;
var overriddenModel = IbmWatsonxChatCompletionModel.of(ibmWatsonxChatCompletionModel, inputs.getRequest());
var manager = new GenericRequestManager<>(
getServiceComponents().threadPool(),
overriddenModel,
UNIFIED_CHAT_COMPLETION_HANDLER,
unifiedChatInput -> new IbmWatsonxChatCompletionRequest(unifiedChatInput, overriddenModel),
UnifiedChatInput.class
);
var errorMessage = IbmWatsonxActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId());
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
action.execute(inputs, timeout, listener);
}
@Override
@ -331,7 +375,7 @@ public class IbmWatsonxService extends SenderService {
configurationMap.put(
API_VERSION,
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The IBM Watsonx API version ID to use.")
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The IBM watsonx API version ID to use.")
.setLabel("API Version")
.setRequired(true)
.setSensitive(false)

View File

@ -7,26 +7,48 @@
package org.elasticsearch.xpack.inference.services.ibmwatsonx.action;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.common.Truncator;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxCompletionResponseHandler;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxRerankRequestManager;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
/**
* IbmWatsonxActionCreator is responsible for creating executable actions for various models.
* It implements the IbmWatsonxActionVisitor interface to provide specific implementations.
*/
public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
private final Sender sender;
private final ServiceComponents serviceComponents;
static final String COMPLETION_REQUEST_TYPE = "IBM watsonx completions";
static final String USER_ROLE = "user";
static final ResponseHandler COMPLETION_HANDLER = new IbmWatsonxCompletionResponseHandler(
COMPLETION_REQUEST_TYPE,
OpenAiChatCompletionResponseEntity::fromResponse
);
public IbmWatsonxActionCreator(Sender sender, ServiceComponents serviceComponents) {
this.sender = Objects.requireNonNull(sender);
this.serviceComponents = Objects.requireNonNull(serviceComponents);
@ -34,7 +56,7 @@ public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
@Override
public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings) {
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM WatsonX embeddings");
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM watsonx embeddings");
return new SenderExecutableAction(
sender,
getEmbeddingsRequestManager(model, serviceComponents.truncator(), serviceComponents.threadPool()),
@ -46,10 +68,24 @@ public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
public ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings) {
var overriddenModel = IbmWatsonxRerankModel.of(model, taskSettings);
var requestCreator = IbmWatsonxRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Ibm Watsonx rerank");
var failedToSendRequestErrorMessage = buildErrorMessage(TaskType.RERANK, overriddenModel.getInferenceEntityId());
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
}
/**
 * Creates the action that sends a single-input completion request for the given chat completion model.
 * The input is wrapped into a {@link UnifiedChatInput} using the {@code user} role before the request is built.
 */
@Override
public ExecutableAction create(IbmWatsonxChatCompletionModel chatCompletionModel) {
    var sendFailureMessage = buildErrorMessage(TaskType.COMPLETION, chatCompletionModel.getInferenceEntityId());
    var requestManager = new GenericRequestManager<>(
        serviceComponents.threadPool(),
        chatCompletionModel,
        COMPLETION_HANDLER,
        completionInput -> new IbmWatsonxChatCompletionRequest(new UnifiedChatInput(completionInput, USER_ROLE), chatCompletionModel),
        ChatCompletionInput.class
    );
    return new SingleInputSenderExecutableAction(sender, requestManager, sendFailureMessage, COMPLETION_REQUEST_TYPE);
}
protected IbmWatsonxEmbeddingsRequestManager getEmbeddingsRequestManager(
IbmWatsonxEmbeddingsModel model,
Truncator truncator,
@ -57,4 +93,15 @@ public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
) {
return new IbmWatsonxEmbeddingsRequestManager(model, truncator, threadPool);
}
/**
 * Formats the message used when an IBM watsonx request could not be sent.
 *
 * @param requestType The task type of the failed request (e.g. COMPLETION, TEXT_EMBEDDING, RERANK); its
 *                    {@code toString()} value is embedded in the message.
 * @param inferenceId The ID of the inference entity the request originated from.
 * @return The formatted error message.
 */
public static String buildErrorMessage(TaskType requestType, String inferenceId) {
    var taskName = requestType.toString();
    return format("Failed to send IBM watsonx %s request from inference entity id [%s]", taskName, inferenceId);
}
}

View File

@ -8,13 +8,42 @@
package org.elasticsearch.xpack.inference.services.ibmwatsonx.action;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
import java.util.Map;
/**
* Visitor for creating {@link ExecutableAction} instances for IBM watsonx models.
* <p>
* One {@code create} overload exists per supported model type: {@link IbmWatsonxEmbeddingsModel},
* {@link IbmWatsonxRerankModel}, and {@link IbmWatsonxChatCompletionModel}.
*/
public interface IbmWatsonxActionVisitor {
/**
* Creates an {@link ExecutableAction} for the given {@link IbmWatsonxEmbeddingsModel}.
*
* @param model The model to create the action for.
* @param taskSettings The task settings to use.
* @return An {@link ExecutableAction} for the given model.
*/
ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings);
/**
* Creates an {@link ExecutableAction} for the given {@link IbmWatsonxRerankModel}.
*
* @param model The model to create the action for.
* @param taskSettings The task settings to use.
* @return An {@link ExecutableAction} for the given model.
*/
ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings);
/**
* Creates an {@link ExecutableAction} for the given {@link IbmWatsonxChatCompletionModel}.
* Unlike the other overloads, this one takes no task settings.
*
* @param model The model to create the action for.
* @return An {@link ExecutableAction} for the given model.
*/
ExecutableAction create(IbmWatsonxChatCompletionModel model);
}

View File

@ -0,0 +1,143 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion;
import org.apache.http.client.utils.URIBuilder;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ModelSecrets;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionVisitor;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Map;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.COMPLETIONS;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.ML;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.TEXT;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.V1;
/**
 * Model describing an IBM watsonx chat completion endpoint: its service settings, its secrets,
 * and the request URI derived from them.
 */
public class IbmWatsonxChatCompletionModel extends IbmWatsonxModel {

    /**
     * Creates a model by parsing the raw service settings and secrets maps.
     *
     * @param inferenceEntityId The unique identifier for the inference entity.
     * @param taskType The type of task this model is designed for.
     * @param service The name of the service this model belongs to.
     * @param serviceSettings The raw settings for the IBM watsonx chat completion service.
     * @param secrets The raw secrets required for accessing the service; may be {@code null}.
     * @param context The context for parsing configuration settings.
     */
    public IbmWatsonxChatCompletionModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        Map<String, Object> serviceSettings,
        @Nullable Map<String, Object> secrets,
        ConfigurationParseContext context
    ) {
        this(
            inferenceEntityId,
            taskType,
            service,
            IbmWatsonxChatCompletionServiceSettings.fromMap(serviceSettings, context),
            DefaultSecretSettings.fromMap(secrets)
        );
    }

    /**
     * Returns a model whose model ID is overridden by the one in the request, if the request specifies one.
     *
     * @param model The original IbmWatsonxChatCompletionModel.
     * @param request The UnifiedCompletionRequest that may contain a model override.
     * @return The original model when the request names no model; otherwise a new model that copies every
     *         setting from the original except the model ID.
     */
    public static IbmWatsonxChatCompletionModel of(IbmWatsonxChatCompletionModel model, UnifiedCompletionRequest request) {
        if (request.model() == null) {
            // If no model is specified in the request, return the original model
            return model;
        }

        var originalModelServiceSettings = model.getServiceSettings();
        var overriddenServiceSettings = new IbmWatsonxChatCompletionServiceSettings(
            originalModelServiceSettings.uri(),
            originalModelServiceSettings.apiVersion(),
            request.model(),
            originalModelServiceSettings.projectId(),
            originalModelServiceSettings.rateLimitSettings()
        );

        return new IbmWatsonxChatCompletionModel(
            model.getInferenceEntityId(),
            model.getTaskType(),
            model.getConfigurations().getService(),
            overriddenServiceSettings,
            model.getSecretSettings()
        );
    }

    /**
     * Creates a model from already-parsed settings. Package-private: it backs the public constructor and
     * {@link #of(IbmWatsonxChatCompletionModel, UnifiedCompletionRequest)}, and is also reachable from same-package tests.
     */
    IbmWatsonxChatCompletionModel(
        String inferenceEntityId,
        TaskType taskType,
        String service,
        IbmWatsonxChatCompletionServiceSettings serviceSettings,
        @Nullable DefaultSecretSettings secretSettings
    ) {
        super(
            new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings),
            new ModelSecrets(secretSettings),
            serviceSettings
        );
    }

    @Override
    public IbmWatsonxChatCompletionServiceSettings getServiceSettings() {
        return (IbmWatsonxChatCompletionServiceSettings) super.getServiceSettings();
    }

    @Override
    public DefaultSecretSettings getSecretSettings() {
        return (DefaultSecretSettings) super.getSecretSettings();
    }

    /**
     * Builds the request URI from this model's service settings.
     *
     * @return the URI produced by {@link #buildUri(String, String)} for the configured URI and API version
     * @throws IllegalArgumentException if the configured values do not form a valid URI
     */
    public URI uri() {
        var settings = getServiceSettings();
        try {
            return buildUri(settings.uri().toString(), settings.apiVersion());
        } catch (URISyntaxException e) {
            // Surface a specific unchecked exception with context instead of a bare RuntimeException; keep the cause.
            throw new IllegalArgumentException("Failed to build IBM watsonx chat completion URI from [" + settings.uri() + "]", e);
        }
    }

    /**
     * Accepts a visitor to create an executable action for this model.
     *
     * @param visitor Interface for creating {@link ExecutableAction} instances for IBM watsonx models.
     * @param taskSettings Ignored by this model; the visitor is invoked without task settings.
     * @return the completion action
     */
    public ExecutableAction accept(IbmWatsonxActionVisitor visitor, Map<String, Object> taskSettings) {
        return visitor.create(this);
    }

    /**
     * Builds an {@code https} URI for the completions endpoint.
     *
     * @param uri Value used as the host component of the resulting URI.
     * @param apiVersion Value of the {@code version} query parameter.
     * @return the assembled URI
     * @throws URISyntaxException if the components do not form a valid URI
     */
    public static URI buildUri(String uri, String apiVersion) throws URISyntaxException {
        return new URIBuilder().setScheme("https")
            .setHost(uri)
            .setPathSegments(ML, V1, TEXT, COMPLETIONS)
            .setParameter("version", apiVersion)
            .build();
    }
}

View File

@ -0,0 +1,193 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.ServiceSettings;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxRateLimitServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import java.io.IOException;
import java.net.URI;
import java.util.Map;
import java.util.Objects;
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.PROJECT_ID;
public class IbmWatsonxChatCompletionServiceSettings extends FilteredXContentObject
implements
ServiceSettings,
IbmWatsonxRateLimitServiceSettings {
public static final String NAME = "ibm_watsonx_completion_service_settings";
/**
* Rate limits are defined at
* <a href="https://www.ibm.com/docs/en/watsonx/saas?topic=learning-watson-machine-plans">Watson Machine Learning plans</a>.
* For the Lite plan, the limit is 120 requests per minute.
*/
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);
public static IbmWatsonxChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
ValidationException validationException = new ValidationException();
String url = extractRequiredString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException);
String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
String projectId = extractRequiredString(map, PROJECT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
validationException,
IbmWatsonxService.NAME,
context
);
if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
}
return new IbmWatsonxChatCompletionServiceSettings(uri, apiVersion, modelId, projectId, rateLimitSettings);
}
private final URI uri;
private final String apiVersion;
private final String modelId;
private final String projectId;
private final RateLimitSettings rateLimitSettings;
public IbmWatsonxChatCompletionServiceSettings(
URI uri,
String apiVersion,
String modelId,
String projectId,
@Nullable RateLimitSettings rateLimitSettings
) {
this.uri = uri;
this.apiVersion = apiVersion;
this.projectId = projectId;
this.modelId = modelId;
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
}
public IbmWatsonxChatCompletionServiceSettings(StreamInput in) throws IOException {
this.uri = createUri(in.readString());
this.apiVersion = in.readString();
this.modelId = in.readString();
this.projectId = in.readString();
this.rateLimitSettings = new RateLimitSettings(in);
}
public URI uri() {
return uri;
}
public String apiVersion() {
return apiVersion;
}
@Override
public String modelId() {
return modelId;
}
public String projectId() {
return projectId;
}
@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
}
@Override
public String getWriteableName() {
return NAME;
}
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
toXContentFragmentOfExposedFields(builder, params);
builder.endObject();
return builder;
}
/**
 * Writes the fields of these settings into an object that is already open on the builder:
 * URL, API version, model id and project id, followed by the rate limit settings.
 *
 * @param builder the builder to write to
 * @param params  serialization parameters, forwarded to the rate limit settings
 * @return the builder that was passed in
 * @throws IOException if writing to the builder fails
 */
@Override
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
    builder.field(URL, uri.toString())
        .field(API_VERSION, apiVersion)
        .field(MODEL_ID, modelId)
        .field(PROJECT_ID, projectId);
    rateLimitSettings.toXContent(builder, params);
    return builder;
}
/**
 * @return {@link TransportVersions#ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED}, the transport version
 *         reported as the minimum supported by these settings
 */
@Override
public TransportVersion getMinimalSupportedVersion() {
return TransportVersions.ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED;
}
/**
 * Serializes the settings to a stream.
 * The write order (URI string, API version, model id, project id, rate limit settings)
 * must stay in sync with the read order of the {@code StreamInput} constructor.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
// The URI is sent in string form; the reading side rebuilds it.
out.writeString(uri.toString());
out.writeString(apiVersion);
out.writeString(modelId);
out.writeString(projectId);
rateLimitSettings.writeTo(out);
}
/**
 * Two settings instances are equal when they are of the same class and their URI, API version,
 * model id, project id and rate limit settings are all equal.
 *
 * @param object the object to compare with
 * @return {@code true} if {@code object} is an equal settings instance
 */
@Override
public boolean equals(Object object) {
    if (this == object) {
        return true;
    }
    if (object == null) {
        return false;
    }
    if (getClass() != object.getClass()) {
        return false;
    }

    IbmWatsonxChatCompletionServiceSettings other = (IbmWatsonxChatCompletionServiceSettings) object;
    boolean sameEndpoint = Objects.equals(uri, other.uri) && Objects.equals(apiVersion, other.apiVersion);
    boolean sameTarget = Objects.equals(modelId, other.modelId) && Objects.equals(projectId, other.projectId);
    return sameEndpoint && sameTarget && Objects.equals(rateLimitSettings, other.rateLimitSettings);
}
/**
 * @return a hash over the same five fields that {@link #equals(Object)} compares
 */
@Override
public int hashCode() {
return Objects.hash(uri, apiVersion, modelId, projectId, rateLimitSettings);
}
}

View File

@ -52,7 +52,7 @@ public class IbmWatsonxEmbeddingsServiceSettings extends FilteredXContentObject
/**
* Rate limits are defined at
* <a href="https://www.ibm.com/docs/en/watsonx/saas?topic=learning-watson-machine-plans">Watson Machine Learning plans</a>.
* For Lite plan, you've 120 requests per minute.
* For the Lite plan, the limit is 120 requests per minute.
*/
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);

View File

@ -0,0 +1,79 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.ibmwatsonx.request;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
/**
 * Request for an IBM watsonx chat completion.
 * It POSTs the JSON form of an {@link IbmWatsonxChatCompletionRequestEntity} to the model's URI
 * and is never truncated.
 */
public class IbmWatsonxChatCompletionRequest implements IbmWatsonxRequest {

    private final IbmWatsonxChatCompletionModel model;
    private final UnifiedChatInput chatInput;

    /**
     * @param chatInput the chat input to send; must not be {@code null}
     * @param model     the model describing where and how to send the request; must not be {@code null}
     */
    public IbmWatsonxChatCompletionRequest(UnifiedChatInput chatInput, IbmWatsonxChatCompletionModel model) {
        this.chatInput = Objects.requireNonNull(chatInput);
        this.model = Objects.requireNonNull(model);
    }

    /**
     * Builds the HTTP POST: UTF-8 encoded JSON body, JSON content type header, and auth decoration.
     */
    @Override
    public HttpRequest createHttpRequest() {
        HttpPost post = new HttpPost(model.uri());

        String json = Strings.toString(new IbmWatsonxChatCompletionRequestEntity(chatInput, model));
        post.setEntity(new ByteArrayEntity(json.getBytes(StandardCharsets.UTF_8)));
        post.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());

        decorateWithAuth(post);
        return new HttpRequest(post, getInferenceEntityId());
    }

    @Override
    public URI getURI() {
        return model.uri();
    }

    /**
     * Adds the bearer token derived from the model's secret settings to the given request.
     */
    public void decorateWithAuth(HttpPost httpPost) {
        IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId());
    }

    @Override
    public Request truncate() {
        // Chat completion requests for IBM watsonx are sent as-is, so there is nothing to shorten.
        return this;
    }

    @Override
    public boolean[] getTruncationInfo() {
        // Nothing is ever truncated, so there is no truncation info to report.
        return null;
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    @Override
    public boolean isStreaming() {
        return chatInput.stream();
    }
}

View File

@ -0,0 +1,47 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.ibmwatsonx.request;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
import java.io.IOException;
import java.util.Objects;
/**
 * Body of an IBM watsonx chat completion request.
 * Serializes as a single object holding the model's project id followed by the fields written by
 * the shared {@link UnifiedChatCompletionRequestEntity}.
 */
public class IbmWatsonxChatCompletionRequestEntity implements ToXContentObject {

    private static final String PROJECT_ID_FIELD = "project_id";

    private final IbmWatsonxChatCompletionModel model;
    private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;

    /**
     * @param unifiedChatInput the chat input wrapped by the unified request entity
     * @param model            the model supplying project id and model id; must not be {@code null}
     */
    public IbmWatsonxChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, IbmWatsonxChatCompletionModel model) {
        this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
        this.model = Objects.requireNonNull(model);
    }

    /**
     * Writes the request object: {@code project_id} first, then the unified chat completion fields.
     */
    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(PROJECT_ID_FIELD, model.getServiceSettings().projectId());

        // The unified entity writes into the object opened above; the params wrapper carries the model id.
        Params unifiedParams = UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(
            model.getServiceSettings().modelId(),
            params
        );
        unifiedRequestEntity.toXContent(builder, unifiedParams);

        return builder.endObject();
    }
}

View File

@ -14,6 +14,7 @@ public class IbmWatsonxUtils {
public static final String TEXT = "text";
public static final String EMBEDDINGS = "embeddings";
public static final String RERANKS = "reranks";
public static final String COMPLETIONS = "chat";
private IbmWatsonxUtils() {}

View File

@ -100,8 +100,8 @@ public class IbmWatsonxRerankModel extends IbmWatsonxModel {
/**
* Accepts a visitor to create an executable action. The returned action will not return documents in the response.
* @param visitor _
* @param taskSettings _
* @param visitor Interface for creating {@link ExecutableAction} instances for IBM watsonx models.
* @param taskSettings Settings in the request to override the model's defaults
* @return the rerank action
*/
@Override

View File

@ -41,7 +41,7 @@ public class IbmWatsonxRerankServiceSettings extends FilteredXContentObject impl
/**
* Rate limits are defined at
* <a href="https://www.ibm.com/docs/en/watsonx/saas?topic=learning-watson-machine-plans">Watson Machine Learning plans</a>.
* For Lite plan, you've 120 requests per minute.
* For the Lite plan, the limit is 120 requests per minute.
*/
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);

View File

@ -28,7 +28,7 @@ import static org.elasticsearch.xpack.inference.external.response.XContentUtils.
public class IbmWatsonxEmbeddingsResponseEntity {
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in IBM Watsonx embeddings response";
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in IBM watsonx embeddings response";
public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);

View File

@ -32,7 +32,7 @@ public class IbmWatsonxRankedResponseEntity {
private static final Logger logger = LogManager.getLogger(IbmWatsonxRankedResponseEntity.class);
/**
* Parses the Ibm Watsonx ranked response.
* Parses the IBM watsonx ranked response.
*
* For a request like:
* "model": "rerank-english-v2.0",
@ -71,7 +71,7 @@ public class IbmWatsonxRankedResponseEntity {
* ],
* }
*
* @param response the http response from ibm watsonx
* @param response the http response from IBM watsonx
* @return the parsed response
* @throws IOException if there is an error parsing the response
*/

View File

@ -84,8 +84,8 @@ public class JinaAIRerankModel extends JinaAIModel {
/**
* Accepts a visitor to create an executable action. The returned action will not return documents in the response.
* @param visitor _
* @param taskSettings _
* @param visitor Interface for creating {@link ExecutableAction} instances for Jina AI models.
* @param taskSettings Settings in the request to override the model's defaults
* @return the rerank action
*/
@Override

View File

@ -37,8 +37,8 @@ import static org.elasticsearch.core.Strings.format;
public class MistralActionCreator implements MistralActionVisitor {
public static final String COMPLETION_ERROR_PREFIX = "Mistral completions";
static final String USER_ROLE = "user";
static final ResponseHandler COMPLETION_HANDLER = new MistralCompletionResponseHandler(
public static final String USER_ROLE = "user";
public static final ResponseHandler COMPLETION_HANDLER = new MistralCompletionResponseHandler(
"mistral completions",
OpenAiChatCompletionResponseEntity::fromResponse
);

View File

@ -109,8 +109,8 @@ public class VoyageAIRerankModel extends VoyageAIModel {
/**
* Accepts a visitor to create an executable action. The returned action will not return documents in the response.
* @param visitor _
* @param taskSettings _
* @param visitor Interface for creating {@link ExecutableAction} instances for Voyage AI models.
* @param taskSettings Settings in the request to override the model's defaults
* @return the rerank action
*/
@Override

View File

@ -0,0 +1,154 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.List;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
/**
 * Shared failure-path tests for chat completion {@link ExecutableAction}s.
 * Subclasses provide the action under test and the service-specific error messages.
 */
public abstract class ChatCompletionActionTests extends ESTestCase {

    protected static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);

    protected final MockWebServer webServer = new MockWebServer();
    protected HttpClientManager clientManager;
    protected ThreadPool threadPool;

    /** Creates the action under test for the given URL and sender. */
    protected abstract ExecutableAction createAction(String url, Sender sender) throws URISyntaxException;

    /** Message expected when more than one input is passed to the action. */
    protected abstract String getOneInputError();

    /** Message expected when the sender reports a failure through the listener. */
    protected abstract String getFailedToSendError();

    @Before
    public void init() throws Exception {
        webServer.start();
        threadPool = createThreadPool(inferenceUtilityPool());
        clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
    }

    @After
    public void shutdown() throws IOException {
        clientManager.close();
        terminate(threadPool);
        webServer.close();
    }

    public void testExecute_ThrowsElasticsearchException() throws URISyntaxException {
        Sender throwingSender = mock(Sender.class);
        doThrow(new ElasticsearchException("failed")).when(throwingSender).send(any(), any(), any(), any());

        ExecutableAction action = createAction(getUrl(webServer), throwingSender);
        PlainActionFuture<InferenceServiceResults> future = executeWithInputs(action, List.of("abc"));

        ElasticsearchException failure = expectThrows(ElasticsearchException.class, () -> future.actionGet(TIMEOUT));
        assertThat(failure.getMessage(), is("failed"));
    }

    public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() throws URISyntaxException {
        Sender failingSender = mock(Sender.class);
        doAnswer(invocation -> {
            // The fourth argument of send(...) is the listener; fail it instead of throwing.
            ActionListener<InferenceServiceResults> senderListener = invocation.getArgument(3);
            senderListener.onFailure(new IllegalStateException("failed"));
            return Void.TYPE;
        }).when(failingSender).send(any(), any(), any(), any());

        ExecutableAction action = createAction(getUrl(webServer), failingSender);
        PlainActionFuture<InferenceServiceResults> future = executeWithInputs(action, List.of("abc"));

        ElasticsearchException failure = expectThrows(ElasticsearchException.class, () -> future.actionGet(TIMEOUT));
        assertThat(failure.getMessage(), is(getFailedToSendError()));
    }

    public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException, URISyntaxException {
        var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

        try (var sender = createSender(senderFactory)) {
            sender.start();
            webServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson()));

            ExecutableAction action = createAction(getUrl(webServer), sender);
            PlainActionFuture<InferenceServiceResults> future = executeWithInputs(action, List.of("abc", "def"));

            ElasticsearchStatusException failure = expectThrows(ElasticsearchStatusException.class, () -> future.actionGet(TIMEOUT));
            assertThat(failure.getMessage(), is(getOneInputError()));
            assertThat(failure.status(), is(RestStatus.BAD_REQUEST));
        }
    }

    /** Canned chat completion response body enqueued on the mock web server. */
    protected String getResponseJson() {
        return """
            {
            "id": "9d80f26810ac4e9582f927fcf0512ec7",
            "object": "chat.completion",
            "created": 1748596419,
            "model": "modelId",
            "choices": [
            {
            "index": 0,
            "message": {
            "role": "assistant",
            "tool_calls": null,
            "content": "result content"
            },
            "finish_reason": "length",
            "logprobs": null
            }
            ],
            "usage": {
            "prompt_tokens": 10,
            "total_tokens": 11,
            "completion_tokens": 1
            }
            }
            """;
    }

    /** Runs the action with the given inputs and the default timeout, returning the future that receives the outcome. */
    private static PlainActionFuture<InferenceServiceResults> executeWithInputs(ExecutableAction action, List<String> inputs) {
        PlainActionFuture<InferenceServiceResults> future = new PlainActionFuture<>();
        action.execute(new ChatCompletionInput(inputs), InferenceAction.Request.DEFAULT_TIMEOUT, future);
        return future;
    }
}

View File

@ -918,8 +918,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
String content = XContentHelper.stripWhitespace("""
{
"service": "watsonxai",
"name": "IBM Watsonx",
"task_types": ["text_embedding"],
"name": "IBM watsonx",
"task_types": ["text_embedding", "completion", "chat_completion"],
"configurations": {
"project_id": {
"description": "",
@ -928,7 +928,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding"]
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
},
"model_id": {
"description": "The name of the model to use for the inference task.",
@ -937,16 +937,16 @@ public class IbmWatsonxServiceTests extends ESTestCase {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding"]
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
},
"api_version": {
"description": "The IBM Watsonx API version ID to use.",
"description": "The IBM watsonx API version ID to use.",
"label": "API Version",
"required": true,
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding"]
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
},
"max_input_tokens": {
"description": "Allows you to specify the maximum number of tokens per input.",
@ -964,7 +964,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
"sensitive": false,
"updatable": false,
"type": "str",
"supported_task_types": ["text_embedding"]
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
}
}
}

View File

@ -0,0 +1,50 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.ibmwatsonx.action;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ChatCompletionActionTests;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest;
import java.net.URI;
import java.net.URISyntaxException;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator.COMPLETION_HANDLER;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator.USER_ROLE;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModelTests.createModel;
public class IbmWatsonxChatCompletionActionTests extends ChatCompletionActionTests {
public static final URI TEST_URI = URI.create("abc.com");
protected ExecutableAction createAction(String url, Sender sender) throws URISyntaxException {
var model = createModel(TEST_URI, randomAlphaOfLength(8), randomAlphaOfLength(8), randomAlphaOfLength(8), randomAlphaOfLength(8));
var manager = new GenericRequestManager<>(
threadPool,
model,
COMPLETION_HANDLER,
inputs -> new IbmWatsonxChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
ChatCompletionInput.class
);
var errorMessage = constructFailedToSendRequestMessage("watsonx chat completions");
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "watsonx chat completions");
}
protected String getFailedToSendError() {
return "Failed to send watsonx chat completions request. Cause: failed";
}
protected String getOneInputError() {
return "watsonx chat completions only accepts 1 input";
}
}

View File

@ -180,7 +180,7 @@ public class IbmWatsonxEmbeddingsActionTests extends ESTestCase {
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is("Failed to send IBM Watsonx embeddings request. Cause: failed"));
assertThat(thrownException.getMessage(), is("Failed to send IBM watsonx embeddings request. Cause: failed"));
}
public void testExecute_ThrowsException() {
@ -204,7 +204,7 @@ public class IbmWatsonxEmbeddingsActionTests extends ESTestCase {
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
assertThat(thrownException.getMessage(), is("Failed to send IBM Watsonx embeddings request. Cause: failed"));
assertThat(thrownException.getMessage(), is("Failed to send IBM watsonx embeddings request. Cause: failed"));
}
private ExecutableAction createAction(
@ -218,7 +218,7 @@ public class IbmWatsonxEmbeddingsActionTests extends ESTestCase {
) {
var model = createModel(modelName, projectId, uri, apiVersion, apiKey, url);
var requestManager = new IbmWatsonxEmbeddingsRequestManagerWithoutAuth(model, TruncatorTests.createTruncator(), threadPool);
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM Watsonx embeddings");
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM watsonx embeddings");
return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
}

View File

@ -0,0 +1,107 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion;
import org.elasticsearch.common.settings.SecureString;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import static org.hamcrest.Matchers.is;
/**
 * Tests how {@link IbmWatsonxChatCompletionModel#of} combines a model with the model id of a
 * {@link UnifiedCompletionRequest}.
 */
public class IbmWatsonxChatCompletionModelTests extends ESTestCase {

    private static final URI TEST_URI = URI.create("abc.com");

    /**
     * Creates a completion model with fixed entity id and service name and no explicit rate limit.
     */
    public static IbmWatsonxChatCompletionModel createModel(URI uri, String apiVersion, String modelId, String projectId, String apiKey)
        throws URISyntaxException {
        var serviceSettings = new IbmWatsonxChatCompletionServiceSettings(uri, apiVersion, modelId, projectId, null);
        var secretSettings = new DefaultSecretSettings(new SecureString(apiKey.toCharArray()));
        return new IbmWatsonxChatCompletionModel("id", TaskType.COMPLETION, "service", serviceSettings, secretSettings);
    }

    public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() throws URISyntaxException {
        var original = createModel(TEST_URI, "apiVersion", "modelId", "projectId", "apiKey");

        var overridden = IbmWatsonxChatCompletionModel.of(original, requestWithModelId("different_model"));

        assertThat(overridden.getServiceSettings().modelId(), is("different_model"));
    }

    public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() throws URISyntaxException {
        var original = createModel(TEST_URI, "apiVersion", null, "projectId", "apiKey");

        var overridden = IbmWatsonxChatCompletionModel.of(original, requestWithModelId("different_model"));

        assertThat(overridden.getServiceSettings().modelId(), is("different_model"));
    }

    public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() throws URISyntaxException {
        var original = createModel(TEST_URI, "apiVersion", null, "projectId", "apiKey");

        var overridden = IbmWatsonxChatCompletionModel.of(original, requestWithModelId(null));

        assertNull(overridden.getServiceSettings().modelId());
    }

    public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() throws URISyntaxException {
        var original = createModel(TEST_URI, "apiVersion", "modelId", "projectId", "apiKey");

        // A null model id in the request means "do not override the model".
        var overridden = IbmWatsonxChatCompletionModel.of(original, requestWithModelId(null));

        assertThat(overridden.getServiceSettings().modelId(), is("modelId"));
    }

    /**
     * Builds a request with a single "hello" message and the given model id; every other option is left null.
     */
    private static UnifiedCompletionRequest requestWithModelId(String requestModelId) {
        var message = new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null);
        return new UnifiedCompletionRequest(List.of(message), requestModelId, null, null, null, null, null, null);
    }
}

View File

@ -0,0 +1,173 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.ValidationException;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.ServiceFields;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
import java.io.IOException;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
public class IbmWatsonxChatCompletionServiceSettingsTests extends AbstractWireSerializingTestCase<IbmWatsonxChatCompletionServiceSettings> {
private static final URI TEST_URI = URI.create("abc.com");
private static IbmWatsonxChatCompletionServiceSettings createRandom() {
return new IbmWatsonxChatCompletionServiceSettings(
TEST_URI,
randomAlphaOfLength(8),
randomAlphaOfLength(8),
randomAlphaOfLength(8),
randomFrom(RateLimitSettingsTests.createRandom(), null)
);
}
private IbmWatsonxChatCompletionServiceSettings getServiceSettings(Map<String, String> map) {
return IbmWatsonxChatCompletionServiceSettings.fromMap(new HashMap<>(map), ConfigurationParseContext.PERSISTENT);
}
public void testFromMap_WithAllParameters_CreatesSettingsCorrectly() {
var model = randomAlphaOfLength(8);
var projectId = randomAlphaOfLength(8);
var apiVersion = randomAlphaOfLength(8);
var serviceSettings = getServiceSettings(
Map.of(
ServiceFields.URL,
TEST_URI.toString(),
IbmWatsonxServiceFields.API_VERSION,
apiVersion,
ServiceFields.MODEL_ID,
model,
IbmWatsonxServiceFields.PROJECT_ID,
projectId
)
);
assertThat(serviceSettings, is(new IbmWatsonxChatCompletionServiceSettings(TEST_URI, apiVersion, model, projectId, null)));
}
public void testFromMap_Fails_WithoutRequiredParam_Url() {
var ex = expectThrows(
ValidationException.class,
() -> getServiceSettings(
Map.of(
IbmWatsonxServiceFields.API_VERSION,
randomAlphaOfLength(8),
ServiceFields.MODEL_ID,
randomAlphaOfLength(8),
IbmWatsonxServiceFields.PROJECT_ID,
randomAlphaOfLength(8)
)
)
);
assertThat(ex.getMessage(), equalTo(generateErrorMessage("url")));
}
public void testFromMap_Fails_WithoutRequiredParam_ApiVersion() {
var ex = expectThrows(
ValidationException.class,
() -> getServiceSettings(
Map.of(
ServiceFields.URL,
TEST_URI.toString(),
ServiceFields.MODEL_ID,
randomAlphaOfLength(8),
IbmWatsonxServiceFields.PROJECT_ID,
randomAlphaOfLength(8)
)
)
);
assertThat(ex.getMessage(), equalTo(generateErrorMessage("api_version")));
}
public void testFromMap_Fails_WithoutRequiredParam_ModelId() {
var ex = expectThrows(
ValidationException.class,
() -> getServiceSettings(
Map.of(
ServiceFields.URL,
TEST_URI.toString(),
IbmWatsonxServiceFields.API_VERSION,
randomAlphaOfLength(8),
IbmWatsonxServiceFields.PROJECT_ID,
randomAlphaOfLength(8)
)
)
);
assertThat(ex.getMessage(), equalTo(generateErrorMessage("model_id")));
}
public void testFromMap_Fails_WithoutRequiredParam_ProjectId() {
var ex = expectThrows(
ValidationException.class,
() -> getServiceSettings(
Map.of(
ServiceFields.URL,
TEST_URI.toString(),
IbmWatsonxServiceFields.API_VERSION,
randomAlphaOfLength(8),
ServiceFields.MODEL_ID,
randomAlphaOfLength(8)
)
)
);
assertThat(ex.getMessage(), equalTo(generateErrorMessage("project_id")));
}
private String generateErrorMessage(String field) {
return "Validation Failed: 1: [service_settings] does not contain the required setting [" + field + "];";
}
public void testToXContent_WritesAllValues() throws IOException {
var entity = new IbmWatsonxChatCompletionServiceSettings(TEST_URI, "2024-05-02", "model", "project_id", null);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
entity.toXContent(builder, null);
String xContentResult = Strings.toString(builder);
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
{
"url":"abc.com",
"api_version":"2024-05-02",
"model_id":"model",
"project_id":"project_id",
"rate_limit": {
"requests_per_minute":120
}
}"""));
}
@Override
protected Writeable.Reader<IbmWatsonxChatCompletionServiceSettings> instanceReader() {
return IbmWatsonxChatCompletionServiceSettings::new;
}
@Override
protected IbmWatsonxChatCompletionServiceSettings createTestInstance() {
return createRandom();
}
@Override
protected IbmWatsonxChatCompletionServiceSettings mutateInstance(IbmWatsonxChatCompletionServiceSettings instance) throws IOException {
return randomValueOtherThan(instance, IbmWatsonxChatCompletionServiceSettingsTests::createRandom);
}
}

View File

@ -0,0 +1,66 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.ibmwatsonx.request;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.json.JsonXContent;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModelTests.createModel;
/**
 * Verifies the JSON produced by {@link IbmWatsonxChatCompletionRequestEntity}.
 */
public class IbmWatsonxChatCompletionRequestEntityTests extends ESTestCase {

    private static final String ROLE = "user";

    public void testModelUserFieldsSerialization() throws IOException, URISyntaxException {
        // A streaming request carrying a single user message.
        var messages = new ArrayList<UnifiedCompletionRequest.Message>();
        messages.add(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("test content"), ROLE, null, null));
        UnifiedChatInput chatInput = new UnifiedChatInput(UnifiedCompletionRequest.of(messages), true);

        IbmWatsonxChatCompletionModel model = createModel(new URI("abc.com"), "apiVersion", "modelId", "projectId", "apiKey");
        var requestEntity = new IbmWatsonxChatCompletionRequestEntity(chatInput, model);

        XContentBuilder builder = JsonXContent.contentBuilder();
        requestEntity.toXContent(builder, ToXContent.EMPTY_PARAMS);
        String actualJson = Strings.toString(builder);

        String expectedJson = """
            {
            "project_id": "projectId",
            "messages": [
            {
            "content": "test content",
            "role": "user"
            }
            ],
            "model": "modelId",
            "n": 1,
            "stream": true
            }
            """;
        assertEquals(XContentHelper.stripWhitespace(expectedJson), actualJson);
    }
}

View File

@ -0,0 +1,106 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/
package org.elasticsearch.xpack.inference.services.ibmwatsonx.request;
import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModelTests;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.List;
import java.util.Map;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
import static org.hamcrest.Matchers.aMapWithSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
/**
 * Tests for the HTTP request built by {@link IbmWatsonxChatCompletionRequest}.
 * <p>
 * Requests are created through a private subclass that replaces {@code decorateWithAuth}
 * with a fixed {@code Authorization} header value.
 */
public class IbmWatsonxChatCompletionRequestTests extends ESTestCase {
    private static final String AUTH_HEADER_VALUE = "foo";
    private static final String API_COMPLETIONS_PATH = "https://abc.com/ml/v1/text/chat?version=apiVersion";

    /** The request body carries {@code "stream": true} when streaming is requested. */
    public void testCreateRequest_WithStreaming() throws IOException, URISyntaxException {
        assertCreateRequestWithStreaming(true);
    }

    /** The request body carries {@code "stream": false} when streaming is not requested. */
    public void testCreateRequest_WithNoStreaming() throws IOException, URISyntaxException {
        assertCreateRequestWithStreaming(false);
    }

    /**
     * Truncating a chat completion request must leave the target URI and the message content untouched.
     */
    public void testTruncate_DoesNotReduceInputTextSize() throws IOException, URISyntaxException {
        String input = randomAlphaOfLength(5);
        String model = randomAlphaOfLength(5);

        var request = createRequest(randomAlphaOfLength(5), input, model, true);
        var truncatedRequest = request.truncate();
        // Assert on the truncated request: it is the object this test is about, not the original one.
        assertThat(truncatedRequest.getURI().toString(), is(API_COMPLETIONS_PATH));

        var httpRequest = truncatedRequest.createHttpRequest();
        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));

        var httpPost = (HttpPost) httpRequest.httpRequestBase();
        var requestMap = entityAsMap(httpPost.getEntity().getContent());
        assertThat(requestMap, aMapWithSize(5));
        // The project id is generated inside createRequest, so only its presence can be checked here.
        assertNotNull(requestMap.get("project_id"));
        assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
        assertThat(requestMap.get("model"), is(model));
        assertThat(requestMap.get("n"), is(1));
        assertTrue((Boolean) requestMap.get("stream"));
        assertNull(requestMap.get("stream_options"));
    }

    /** Chat completion requests report no truncation info. */
    public void testTruncationInfo_ReturnsNull() throws URISyntaxException {
        var request = createRequest(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5), true);
        assertNull(request.getTruncationInfo());
    }

    /**
     * Creates a non-streaming request.
     *
     * @param apiKey the API key passed to the model factory
     * @param input  the single user message
     * @param model  the model id, may be {@code null}
     * @return a request whose auth decoration is replaced by a fixed header
     * @throws URISyntaxException if the hard-coded test URI cannot be parsed
     */
    public static IbmWatsonxChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model)
        throws URISyntaxException {
        return createRequest(apiKey, input, model, false);
    }

    /**
     * Creates a request for a single {@code "user"} message against {@code abc.com} with API version {@code apiVersion}
     * and a random project id.
     *
     * @param apiKey the API key passed to the model factory
     * @param input  the single user message
     * @param model  the model id, may be {@code null}
     * @param stream whether the request asks for a streamed response
     * @return a request whose auth decoration is replaced by a fixed header
     * @throws URISyntaxException if the hard-coded test URI cannot be parsed
     */
    public static IbmWatsonxChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model, boolean stream)
        throws URISyntaxException {
        var chatCompletionModel = IbmWatsonxChatCompletionModelTests.createModel(
            new URI("abc.com"),
            "apiVersion",
            model,
            randomAlphaOfLength(5),
            apiKey
        );
        return new IbmWatsonxChatCompletionWithoutAuthRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
    }

    /**
     * Request variant that sets a constant {@code Authorization} header instead of running the parent's auth decoration.
     */
    private static class IbmWatsonxChatCompletionWithoutAuthRequest extends IbmWatsonxChatCompletionRequest {
        IbmWatsonxChatCompletionWithoutAuthRequest(UnifiedChatInput input, IbmWatsonxChatCompletionModel model) {
            super(input, model);
        }

        @Override
        public void decorateWithAuth(HttpPost httpPost) {
            httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE);
        }
    }

    /**
     * Builds a request with random key, input and model and verifies that the {@code stream} field of the body
     * equals {@code isStreaming}.
     */
    private void assertCreateRequestWithStreaming(boolean isStreaming) throws URISyntaxException, IOException {
        var request = createRequest(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5), isStreaming);
        var httpRequest = request.createHttpRequest();

        assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
        var httpPost = (HttpPost) httpRequest.httpRequestBase();

        var requestMap = entityAsMap(httpPost.getEntity().getContent());
        assertThat(requestMap.get("stream"), is(isStreaming));
    }
}

View File

@ -112,6 +112,6 @@ public class IbmWatsonxEmbeddingsResponseEntityTests extends ESTestCase {
)
);
assertThat(thrownException.getMessage(), is("Failed to find required field [results] in IBM Watsonx embeddings response"));
assertThat(thrownException.getMessage(), is("Failed to find required field [results] in IBM watsonx embeddings response"));
}
}

View File

@ -8,42 +8,28 @@
package org.elasticsearch.xpack.inference.services.mistral.action;
import org.apache.http.HttpHeaders;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.http.MockRequest;
import org.elasticsearch.test.http.MockResponse;
import org.elasticsearch.test.http.MockWebServer;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.services.ChatCompletionActionTests;
import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest;
import org.junit.After;
import org.junit.Before;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
@ -56,64 +42,15 @@ import static org.elasticsearch.xpack.inference.services.mistral.completion.Mist
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.is;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
public class MistralChatCompletionActionTests extends ESTestCase {
/** Maximum time a test waits on a listener before giving up. */
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
/** Mock HTTP server; started in {@code init()} and closed in {@code shutdown()}. */
private final MockWebServer webServer = new MockWebServer();
/** Thread pool created in {@code init()} and terminated in {@code shutdown()}. */
private ThreadPool threadPool;
/** HTTP client manager created in {@code init()} and closed in {@code shutdown()}. */
private HttpClientManager clientManager;
/**
 * Starts the mock web server and creates the thread pool and HTTP client manager used by each test.
 */
@Before
public void init() throws Exception {
    webServer.start();

    var pool = createThreadPool(inferenceUtilityPool());
    var throttlerManager = mock(ThrottlerManager.class);

    threadPool = pool;
    clientManager = HttpClientManager.create(Settings.EMPTY, pool, mockClusterServiceEmpty(), throttlerManager);
}
/**
 * Releases the per-test resources: closes the HTTP client manager, terminates the thread pool,
 * then closes the mock web server.
 */
@After
public void shutdown() throws IOException {
clientManager.close();
terminate(threadPool);
webServer.close();
}
public void testExecute_ReturnsSuccessfulResponse() throws IOException {
public class MistralChatCompletionActionTests extends ChatCompletionActionTests {
public void testExecute_ReturnsSuccessfulResponse() throws IOException, URISyntaxException {
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
try (var sender = createSender(senderFactory)) {
sender.start();
String responseJson = """
{
"id": "9d80f26810ac4e9582f927fcf0512ec7",
"object": "chat.completion",
"created": 1748596419,
"model": "mistral-small-latest",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"tool_calls": null,
"content": "result content"
},
"finish_reason": "length",
"logprobs": null
}
],
"usage": {
"prompt_tokens": 10,
"total_tokens": 11,
"completion_tokens": 1
}
}
""";
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson()));
var action = createAction(getUrl(webServer), sender);
@ -140,87 +77,7 @@ public class MistralChatCompletionActionTests extends ESTestCase {
}
}
/**
 * When the sender throws an {@link ElasticsearchException} from {@code send}, the listener fails with an
 * {@link ElasticsearchException} whose message is {@code "failed"}.
 */
public void testExecute_ThrowsElasticsearchException() {
    Sender failingSender = mock(Sender.class);
    doThrow(new ElasticsearchException("failed")).when(failingSender).send(any(), any(), any(), any());

    ExecutableAction action = createAction(getUrl(webServer), failingSender);

    var future = new PlainActionFuture<InferenceServiceResults>();
    action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, future);

    ElasticsearchException thrown = expectThrows(ElasticsearchException.class, () -> future.actionGet(TIMEOUT));
    assertThat(thrown.getMessage(), is("failed"));
}
/**
 * When the sender reports an {@link IllegalStateException} through the listener passed as the fourth argument
 * of {@code send}, the caller sees an {@link ElasticsearchException} carrying the "Failed to send" message.
 */
public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
    Sender failingSender = mock(Sender.class);
    doAnswer(call -> {
        // Argument index 3 of send(...) is the listener to notify.
        call.<ActionListener<InferenceServiceResults>>getArgument(3).onFailure(new IllegalStateException("failed"));
        return Void.TYPE;
    }).when(failingSender).send(any(), any(), any(), any());

    ExecutableAction action = createAction(getUrl(webServer), failingSender);

    var future = new PlainActionFuture<InferenceServiceResults>();
    action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, future);

    ElasticsearchException thrown = expectThrows(ElasticsearchException.class, () -> future.actionGet(TIMEOUT));
    assertThat(thrown.getMessage(), is("Failed to send mistral chat completions request. Cause: failed"));
}
/**
 * Executing the action with two inputs fails with a {@code BAD_REQUEST} {@link ElasticsearchStatusException}
 * stating that only one input is accepted.
 */
public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
    var factory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);

    try (var httpSender = createSender(factory)) {
        httpSender.start();

        String successBody = """
            {
            "id": "9d80f26810ac4e9582f927fcf0512ec7",
            "object": "chat.completion",
            "created": 1748596419,
            "model": "mistral-small-latest",
            "choices": [
            {
            "index": 0,
            "message": {
            "role": "assistant",
            "tool_calls": null,
            "content": "result content"
            },
            "finish_reason": "length",
            "logprobs": null
            }
            ],
            "usage": {
            "prompt_tokens": 10,
            "total_tokens": 11,
            "completion_tokens": 1
            }
            }
            """;
        var successResponse = new MockResponse().setResponseCode(200).setBody(successBody);
        webServer.enqueue(successResponse);

        ExecutableAction action = createAction(getUrl(webServer), httpSender);

        var future = new PlainActionFuture<InferenceServiceResults>();
        action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, future);

        ElasticsearchStatusException thrown = expectThrows(ElasticsearchStatusException.class, () -> future.actionGet(TIMEOUT));
        assertThat(thrown.getMessage(), is("mistral chat completions only accepts 1 input"));
        assertThat(thrown.status(), is(RestStatus.BAD_REQUEST));
    }
}
private ExecutableAction createAction(String url, Sender sender) {
protected ExecutableAction createAction(String url, Sender sender) {
var model = createCompletionModel("secret", "model");
model.setURI(url);
var manager = new GenericRequestManager<>(
@ -233,4 +90,12 @@ public class MistralChatCompletionActionTests extends ESTestCase {
var errorMessage = constructFailedToSendRequestMessage("mistral chat completions");
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "mistral chat completions");
}
/**
 * @return the error message expected when sending a mistral chat completions request fails
 */
protected String getFailedToSendError() {
    final String expectedMessage = "Failed to send mistral chat completions request. Cause: failed";
    return expectedMessage;
}
/**
 * @return the error message expected when more than one input is passed to mistral chat completions
 */
protected String getOneInputError() {
    final String expectedMessage = "mistral chat completions only accepts 1 input";
    return expectedMessage;
}
}