Add Ibm Granite Completion and Chat Completion support (#129146)
* Add Ibm Granite Completion and Chat Completion support * Apply suggestions * remove ibm watsonx transport version constant * update transport version
This commit is contained in:
parent
82b6e45a81
commit
5d0c5e02bd
|
@ -0,0 +1,5 @@
|
|||
pr: 129146
|
||||
summary: "[ML] Add IBM watsonx Completion and Chat Completion support to the Inference Plugin"
|
||||
area: Machine Learning
|
||||
type: enhancement
|
||||
issues: []
|
|
@ -328,7 +328,7 @@ public class TransportVersions {
|
|||
public static final TransportVersion MAPPINGS_IN_DATA_STREAMS = def(9_112_0_00);
|
||||
public static final TransportVersion PROJECT_STATE_REGISTRY_RECORDS_DELETIONS = def(9_113_0_00);
|
||||
public static final TransportVersion ESQL_SERIALIZE_TIMESERIES_FIELD_TYPE = def(9_114_0_00);
|
||||
|
||||
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED = def(9_115_0_00);
|
||||
/*
|
||||
* STOP! READ THIS FIRST! No, really,
|
||||
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
|
||||
|
|
|
@ -151,7 +151,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
"completion_test_service",
|
||||
"hugging_face",
|
||||
"amazon_sagemaker",
|
||||
"mistral"
|
||||
"mistral",
|
||||
"watsonxai"
|
||||
).toArray()
|
||||
)
|
||||
);
|
||||
|
@ -169,7 +170,8 @@ public class InferenceGetServicesIT extends BaseMockEISAuthServerTest {
|
|||
"hugging_face",
|
||||
"amazon_sagemaker",
|
||||
"googlevertexai",
|
||||
"mistral"
|
||||
"mistral",
|
||||
"watsonxai"
|
||||
).toArray()
|
||||
)
|
||||
);
|
||||
|
|
|
@ -95,6 +95,7 @@ import org.elasticsearch.xpack.inference.services.huggingface.completion.Hugging
|
|||
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.huggingface.rerank.HuggingFaceRerankTaskSettings;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
|
||||
|
@ -469,6 +470,13 @@ public class InferenceNamedWriteablesProvider {
|
|||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(TaskSettings.class, IbmWatsonxRerankTaskSettings.NAME, IbmWatsonxRerankTaskSettings::new)
|
||||
);
|
||||
namedWriteables.add(
|
||||
new NamedWriteableRegistry.Entry(
|
||||
ServiceSettings.class,
|
||||
IbmWatsonxChatCompletionServiceSettings.NAME,
|
||||
IbmWatsonxChatCompletionServiceSettings::new
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
|
||||
|
|
|
@ -80,8 +80,8 @@ public class CohereRerankModel extends CohereModel {
|
|||
|
||||
/**
|
||||
* Accepts a visitor to create an executable action. The returned action will not return documents in the response.
|
||||
* @param visitor _
|
||||
* @param taskSettings _
|
||||
* @param visitor Interface for creating {@link ExecutableAction} instances for Cohere models.
|
||||
* @param taskSettings Settings in the request to override the model's defaults
|
||||
* @return the rerank action
|
||||
*/
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
|
||||
|
||||
import org.elasticsearch.xpack.core.inference.results.UnifiedChatCompletionException;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpResult;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.services.openai.OpenAiUnifiedChatCompletionResponseHandler;
|
||||
|
||||
import java.util.Locale;
|
||||
|
||||
/**
|
||||
* Handles streaming chat completion responses and error parsing for Watsonx inference endpoints.
|
||||
* Adapts the OpenAI handler to support Watsonx's error schema.
|
||||
*/
|
||||
public class IbmWatsonUnifiedChatCompletionResponseHandler extends OpenAiUnifiedChatCompletionResponseHandler {
|
||||
|
||||
private static final String WATSONX_ERROR = "watsonx_error";
|
||||
|
||||
public IbmWatsonUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
|
||||
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Exception buildError(String message, Request request, HttpResult result, ErrorResponse errorResponse) {
|
||||
assert request.isStreaming() : "Only streaming requests support this format";
|
||||
var responseStatusCode = result.response().getStatusLine().getStatusCode();
|
||||
if (request.isStreaming()) {
|
||||
var errorMessage = errorMessage(message, request, result, errorResponse, responseStatusCode);
|
||||
var restStatus = toRestStatus(responseStatusCode);
|
||||
return errorResponse instanceof IbmWatsonxErrorResponseEntity
|
||||
? new UnifiedChatCompletionException(restStatus, errorMessage, WATSONX_ERROR, restStatus.name().toLowerCase(Locale.ROOT))
|
||||
: new UnifiedChatCompletionException(
|
||||
restStatus,
|
||||
errorMessage,
|
||||
createErrorType(errorResponse),
|
||||
restStatus.name().toLowerCase(Locale.ROOT)
|
||||
);
|
||||
} else {
|
||||
return super.buildError(message, request, result, errorResponse);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,25 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
|
||||
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.response.IbmWatsonxErrorResponseEntity;
|
||||
import org.elasticsearch.xpack.inference.services.openai.OpenAiChatCompletionResponseHandler;
|
||||
|
||||
public class IbmWatsonxCompletionResponseHandler extends OpenAiChatCompletionResponseHandler {
|
||||
|
||||
/**
|
||||
* Constructs a IbmWatsonxCompletionResponseHandler with the specified request type and response parser.
|
||||
*
|
||||
* @param requestType The type of request being handled (e.g., "IBM watsonx completions").
|
||||
* @param parseFunction The function to parse the response.
|
||||
*/
|
||||
public IbmWatsonxCompletionResponseHandler(String requestType, ResponseParser parseFunction) {
|
||||
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
|
||||
}
|
||||
}
|
|
@ -35,7 +35,7 @@ public class IbmWatsonxEmbeddingsRequestManager extends IbmWatsonxRequestManager
|
|||
private static final ResponseHandler HANDLER = createEmbeddingsHandler();
|
||||
|
||||
private static ResponseHandler createEmbeddingsHandler() {
|
||||
return new IbmWatsonxResponseHandler("ibm watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse);
|
||||
return new IbmWatsonxResponseHandler("IBM watsonx embeddings", IbmWatsonxEmbeddingsResponseEntity::fromResponse);
|
||||
}
|
||||
|
||||
private final IbmWatsonxEmbeddingsModel model;
|
||||
|
|
|
@ -7,18 +7,19 @@
|
|||
|
||||
package org.elasticsearch.xpack.inference.services.ibmwatsonx;
|
||||
|
||||
import org.elasticsearch.inference.Model;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.inference.TaskSettings;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.services.RateLimitGroupingModel;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionVisitor;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
public abstract class IbmWatsonxModel extends Model {
|
||||
public abstract class IbmWatsonxModel extends RateLimitGroupingModel {
|
||||
|
||||
private final IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings;
|
||||
|
||||
|
@ -49,4 +50,14 @@ public abstract class IbmWatsonxModel extends Model {
|
|||
public IbmWatsonxRateLimitServiceSettings rateLimitServiceSettings() {
|
||||
return rateLimitServiceSettings;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int rateLimitGroupingHash() {
|
||||
return Objects.hash(this.rateLimitServiceSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public RateLimitSettings rateLimitSettings() {
|
||||
return this.rateLimitServiceSettings().rateLimitSettings();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -31,7 +31,7 @@ public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager {
|
|||
|
||||
private static ResponseHandler createIbmWatsonxResponseHandler() {
|
||||
return new IbmWatsonxResponseHandler(
|
||||
"ibm watsonx rerank",
|
||||
"IBM watsonx rerank",
|
||||
(request, response) -> IbmWatsonxRankedResponseEntity.fromResponse(response)
|
||||
);
|
||||
}
|
||||
|
|
|
@ -30,7 +30,10 @@ import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType;
|
|||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
|
||||
import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
|
||||
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
|
@ -40,14 +43,18 @@ import org.elasticsearch.xpack.inference.services.SenderService;
|
|||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceUtils;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
|
||||
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
|
||||
|
||||
import java.util.EnumSet;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
|
||||
|
@ -56,7 +63,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersi
|
|||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation;
|
||||
import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL;
|
||||
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION;
|
||||
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE;
|
||||
|
@ -66,8 +72,16 @@ public class IbmWatsonxService extends SenderService {
|
|||
|
||||
public static final String NAME = "watsonxai";
|
||||
|
||||
private static final String SERVICE_NAME = "IBM Watsonx";
|
||||
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(TaskType.TEXT_EMBEDDING);
|
||||
private static final String SERVICE_NAME = "IBM watsonx";
|
||||
private static final EnumSet<TaskType> supportedTaskTypes = EnumSet.of(
|
||||
TaskType.TEXT_EMBEDDING,
|
||||
TaskType.COMPLETION,
|
||||
TaskType.CHAT_COMPLETION
|
||||
);
|
||||
private static final ResponseHandler UNIFIED_CHAT_COMPLETION_HANDLER = new IbmWatsonUnifiedChatCompletionResponseHandler(
|
||||
"IBM watsonx chat completions",
|
||||
OpenAiChatCompletionResponseEntity::fromResponse
|
||||
);
|
||||
|
||||
public IbmWatsonxService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) {
|
||||
super(factory, serviceComponents);
|
||||
|
@ -148,6 +162,14 @@ public class IbmWatsonxService extends SenderService {
|
|||
secretSettings,
|
||||
context
|
||||
);
|
||||
case CHAT_COMPLETION, COMPLETION -> new IbmWatsonxChatCompletionModel(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
NAME,
|
||||
serviceSettings,
|
||||
secretSettings,
|
||||
context
|
||||
);
|
||||
default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST);
|
||||
};
|
||||
}
|
||||
|
@ -236,6 +258,11 @@ public class IbmWatsonxService extends SenderService {
|
|||
return TransportVersions.V_8_16_0;
|
||||
}
|
||||
|
||||
@Override
|
||||
public Set<TaskType> supportedStreamingTasks() {
|
||||
return EnumSet.of(TaskType.COMPLETION, TaskType.CHAT_COMPLETION);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Model updateModelWithEmbeddingDetails(Model model, int embeddingSize) {
|
||||
if (model instanceof IbmWatsonxEmbeddingsModel embeddingsModel) {
|
||||
|
@ -291,7 +318,24 @@ public class IbmWatsonxService extends SenderService {
|
|||
TimeValue timeout,
|
||||
ActionListener<InferenceServiceResults> listener
|
||||
) {
|
||||
throwUnsupportedUnifiedCompletionOperation(NAME);
|
||||
if (model instanceof IbmWatsonxChatCompletionModel == false) {
|
||||
listener.onFailure(createInvalidModelException(model));
|
||||
return;
|
||||
}
|
||||
|
||||
IbmWatsonxChatCompletionModel ibmWatsonxChatCompletionModel = (IbmWatsonxChatCompletionModel) model;
|
||||
var overriddenModel = IbmWatsonxChatCompletionModel.of(ibmWatsonxChatCompletionModel, inputs.getRequest());
|
||||
var manager = new GenericRequestManager<>(
|
||||
getServiceComponents().threadPool(),
|
||||
overriddenModel,
|
||||
UNIFIED_CHAT_COMPLETION_HANDLER,
|
||||
unifiedChatInput -> new IbmWatsonxChatCompletionRequest(unifiedChatInput, overriddenModel),
|
||||
UnifiedChatInput.class
|
||||
);
|
||||
var errorMessage = IbmWatsonxActionCreator.buildErrorMessage(TaskType.CHAT_COMPLETION, model.getInferenceEntityId());
|
||||
var action = new SenderExecutableAction(getSender(), manager, errorMessage);
|
||||
|
||||
action.execute(inputs, timeout, listener);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
@ -331,7 +375,7 @@ public class IbmWatsonxService extends SenderService {
|
|||
|
||||
configurationMap.put(
|
||||
API_VERSION,
|
||||
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The IBM Watsonx API version ID to use.")
|
||||
new SettingsConfiguration.Builder(supportedTaskTypes).setDescription("The IBM watsonx API version ID to use.")
|
||||
.setLabel("API Version")
|
||||
.setRequired(true)
|
||||
.setSensitive(false)
|
||||
|
|
|
@ -7,26 +7,48 @@
|
|||
|
||||
package org.elasticsearch.xpack.inference.services.ibmwatsonx.action;
|
||||
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.inference.common.Truncator;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceComponents;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxCompletionResponseHandler;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxEmbeddingsRequestManager;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxRerankRequestManager;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
|
||||
import org.elasticsearch.xpack.inference.services.openai.response.OpenAiChatCompletionResponseEntity;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.core.Strings.format;
|
||||
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
|
||||
|
||||
/**
|
||||
* IbmWatsonxActionCreator is responsible for creating executable actions for various models.
|
||||
* It implements the IbmWatsonxActionVisitor interface to provide specific implementations.
|
||||
*/
|
||||
public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
|
||||
private final Sender sender;
|
||||
private final ServiceComponents serviceComponents;
|
||||
|
||||
static final String COMPLETION_REQUEST_TYPE = "IBM watsonx completions";
|
||||
static final String USER_ROLE = "user";
|
||||
static final ResponseHandler COMPLETION_HANDLER = new IbmWatsonxCompletionResponseHandler(
|
||||
COMPLETION_REQUEST_TYPE,
|
||||
OpenAiChatCompletionResponseEntity::fromResponse
|
||||
);
|
||||
|
||||
public IbmWatsonxActionCreator(Sender sender, ServiceComponents serviceComponents) {
|
||||
this.sender = Objects.requireNonNull(sender);
|
||||
this.serviceComponents = Objects.requireNonNull(serviceComponents);
|
||||
|
@ -34,7 +56,7 @@ public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
|
|||
|
||||
@Override
|
||||
public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings) {
|
||||
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM WatsonX embeddings");
|
||||
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM watsonx embeddings");
|
||||
return new SenderExecutableAction(
|
||||
sender,
|
||||
getEmbeddingsRequestManager(model, serviceComponents.truncator(), serviceComponents.threadPool()),
|
||||
|
@ -46,10 +68,24 @@ public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
|
|||
public ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings) {
|
||||
var overriddenModel = IbmWatsonxRerankModel.of(model, taskSettings);
|
||||
var requestCreator = IbmWatsonxRerankRequestManager.of(overriddenModel, serviceComponents.threadPool());
|
||||
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("Ibm Watsonx rerank");
|
||||
var failedToSendRequestErrorMessage = buildErrorMessage(TaskType.RERANK, overriddenModel.getInferenceEntityId());
|
||||
return new SenderExecutableAction(sender, requestCreator, failedToSendRequestErrorMessage);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ExecutableAction create(IbmWatsonxChatCompletionModel chatCompletionModel) {
|
||||
var manager = new GenericRequestManager<>(
|
||||
serviceComponents.threadPool(),
|
||||
chatCompletionModel,
|
||||
COMPLETION_HANDLER,
|
||||
inputs -> new IbmWatsonxChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), chatCompletionModel),
|
||||
ChatCompletionInput.class
|
||||
);
|
||||
|
||||
var failedToSendRequestErrorMessage = buildErrorMessage(TaskType.COMPLETION, chatCompletionModel.getInferenceEntityId());
|
||||
return new SingleInputSenderExecutableAction(sender, manager, failedToSendRequestErrorMessage, COMPLETION_REQUEST_TYPE);
|
||||
}
|
||||
|
||||
protected IbmWatsonxEmbeddingsRequestManager getEmbeddingsRequestManager(
|
||||
IbmWatsonxEmbeddingsModel model,
|
||||
Truncator truncator,
|
||||
|
@ -57,4 +93,15 @@ public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {
|
|||
) {
|
||||
return new IbmWatsonxEmbeddingsRequestManager(model, truncator, threadPool);
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds an error message for IBM watsonx actions.
|
||||
*
|
||||
* @param requestType The task type of the request (e.g. COMPLETION, TEXT_EMBEDDING, RERANK).
|
||||
* @param inferenceId The ID of the inference entity.
|
||||
* @return A formatted error message.
|
||||
*/
|
||||
public static String buildErrorMessage(TaskType requestType, String inferenceId) {
|
||||
return format("Failed to send IBM watsonx %s request from inference entity id [%s]", requestType.toString(), inferenceId);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,13 +8,42 @@
|
|||
package org.elasticsearch.xpack.inference.services.ibmwatsonx.action;
|
||||
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
|
||||
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* Interface for creating {@link ExecutableAction} instances for IBM watsonx models.
|
||||
* <p>
|
||||
* This interface is used to create {@link ExecutableAction} instances for different types of IBM watsonx models, such as
|
||||
* {@link IbmWatsonxEmbeddingsModel}, {@link IbmWatsonxRerankModel}, and {@link IbmWatsonxChatCompletionModel}.
|
||||
*/
|
||||
public interface IbmWatsonxActionVisitor {
|
||||
|
||||
/**
|
||||
* Creates an {@link ExecutableAction} for the given {@link IbmWatsonxEmbeddingsModel}.
|
||||
*
|
||||
* @param model The model to create the action for.
|
||||
* @param taskSettings The task settings to use.
|
||||
* @return An {@link ExecutableAction} for the given model.
|
||||
*/
|
||||
ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings);
|
||||
|
||||
/**
|
||||
* Creates an {@link ExecutableAction} for the given {@link IbmWatsonxRerankModel}.
|
||||
*
|
||||
* @param model The model to create the action for.
|
||||
* @param taskSettings The task settings to use.
|
||||
* @return An {@link ExecutableAction} for the given model.
|
||||
*/
|
||||
ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings);
|
||||
|
||||
/**
|
||||
* Creates an {@link ExecutableAction} for the given {@link IbmWatsonxChatCompletionModel}.
|
||||
*
|
||||
* @param model The model to create the action for.
|
||||
* @return An {@link ExecutableAction} for the given model.
|
||||
*/
|
||||
ExecutableAction create(IbmWatsonxChatCompletionModel model);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,143 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion;
|
||||
|
||||
import org.apache.http.client.utils.URIBuilder;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ModelSecrets;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxModel;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionVisitor;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.COMPLETIONS;
|
||||
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.ML;
|
||||
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.TEXT;
|
||||
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxUtils.V1;
|
||||
|
||||
public class IbmWatsonxChatCompletionModel extends IbmWatsonxModel {
|
||||
|
||||
/**
|
||||
* Constructor for IbmWatsonxChatCompletionModel.
|
||||
*
|
||||
* @param inferenceEntityId The unique identifier for the inference entity.
|
||||
* @param taskType The type of task this model is designed for.
|
||||
* @param service The name of the service this model belongs to.
|
||||
* @param serviceSettings The settings specific to the IBM watsonx chat completion service.
|
||||
* @param secrets The secrets required for accessing the service.
|
||||
* @param context The context for parsing configuration settings.
|
||||
*/
|
||||
public IbmWatsonxChatCompletionModel(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
String service,
|
||||
Map<String, Object> serviceSettings,
|
||||
@Nullable Map<String, Object> secrets,
|
||||
ConfigurationParseContext context
|
||||
) {
|
||||
this(
|
||||
inferenceEntityId,
|
||||
taskType,
|
||||
service,
|
||||
IbmWatsonxChatCompletionServiceSettings.fromMap(serviceSettings, context),
|
||||
DefaultSecretSettings.fromMap(secrets)
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new IbmWatsonxChatCompletionModel with overridden service settings.
|
||||
*
|
||||
* @param model The original IbmWatsonxChatCompletionModel.
|
||||
* @param request The UnifiedCompletionRequest containing the model override.
|
||||
* @return A new IbmWatsonxChatCompletionModel with the overridden model ID.
|
||||
*/
|
||||
public static IbmWatsonxChatCompletionModel of(IbmWatsonxChatCompletionModel model, UnifiedCompletionRequest request) {
|
||||
if (request.model() == null) {
|
||||
// If no model is specified in the request, return the original model
|
||||
return model;
|
||||
}
|
||||
|
||||
var originalModelServiceSettings = model.getServiceSettings();
|
||||
var overriddenServiceSettings = new IbmWatsonxChatCompletionServiceSettings(
|
||||
originalModelServiceSettings.uri(),
|
||||
originalModelServiceSettings.apiVersion(),
|
||||
request.model(),
|
||||
originalModelServiceSettings.projectId(),
|
||||
originalModelServiceSettings.rateLimitSettings()
|
||||
);
|
||||
|
||||
return new IbmWatsonxChatCompletionModel(
|
||||
model.getInferenceEntityId(),
|
||||
model.getTaskType(),
|
||||
model.getConfigurations().getService(),
|
||||
overriddenServiceSettings,
|
||||
model.getSecretSettings()
|
||||
);
|
||||
}
|
||||
|
||||
// Package-private: called by the public constructor and by of(...); also usable directly from tests in this package
|
||||
IbmWatsonxChatCompletionModel(
|
||||
String inferenceEntityId,
|
||||
TaskType taskType,
|
||||
String service,
|
||||
IbmWatsonxChatCompletionServiceSettings serviceSettings,
|
||||
@Nullable DefaultSecretSettings secretSettings
|
||||
) {
|
||||
super(
|
||||
new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings),
|
||||
new ModelSecrets(secretSettings),
|
||||
serviceSettings
|
||||
);
|
||||
}
|
||||
|
||||
@Override
|
||||
public IbmWatsonxChatCompletionServiceSettings getServiceSettings() {
|
||||
return (IbmWatsonxChatCompletionServiceSettings) super.getServiceSettings();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DefaultSecretSettings getSecretSettings() {
|
||||
return (DefaultSecretSettings) super.getSecretSettings();
|
||||
}
|
||||
|
||||
public URI uri() {
|
||||
URI uri;
|
||||
try {
|
||||
uri = buildUri(this.getServiceSettings().uri().toString(), this.getServiceSettings().apiVersion());
|
||||
} catch (URISyntaxException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
|
||||
return uri;
|
||||
}
|
||||
|
||||
/**
|
||||
* Accepts a visitor to create an executable action for this chat completion model.
|
||||
* @param visitor Interface for creating {@link ExecutableAction} instances for IBM watsonx models.
|
||||
* @param taskSettings Ignored by this model; the action is created from the model alone.
|
||||
* @return the completion action
|
||||
*/
|
||||
public ExecutableAction accept(IbmWatsonxActionVisitor visitor, Map<String, Object> taskSettings) {
|
||||
return visitor.create(this);
|
||||
}
|
||||
|
||||
public static URI buildUri(String uri, String apiVersion) throws URISyntaxException {
|
||||
return new URIBuilder().setScheme("https")
|
||||
.setHost(uri)
|
||||
.setPathSegments(ML, V1, TEXT, COMPLETIONS)
|
||||
.setParameter("version", apiVersion)
|
||||
.build();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,193 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion;
|
||||
|
||||
import org.elasticsearch.TransportVersion;
|
||||
import org.elasticsearch.TransportVersions;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.StreamInput;
|
||||
import org.elasticsearch.common.io.stream.StreamOutput;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.inference.ModelConfigurations;
|
||||
import org.elasticsearch.inference.ServiceSettings;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxRateLimitServiceSettings;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxService;
|
||||
import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.Map;
|
||||
import java.util.Objects;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceFields.URL;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.createUri;
|
||||
import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
|
||||
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION;
|
||||
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.PROJECT_ID;
|
||||
|
||||
public class IbmWatsonxChatCompletionServiceSettings extends FilteredXContentObject
|
||||
implements
|
||||
ServiceSettings,
|
||||
IbmWatsonxRateLimitServiceSettings {
|
||||
public static final String NAME = "ibm_watsonx_completion_service_settings";
|
||||
|
||||
/**
|
||||
* Rate limits are defined at
|
||||
* <a href="https://www.ibm.com/docs/en/watsonx/saas?topic=learning-watson-machine-plans">Watson Machine Learning plans</a>.
|
||||
* For the Lite plan, the limit is 120 requests per minute.
|
||||
*/
|
||||
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);
|
||||
|
||||
public static IbmWatsonxChatCompletionServiceSettings fromMap(Map<String, Object> map, ConfigurationParseContext context) {
|
||||
ValidationException validationException = new ValidationException();
|
||||
|
||||
String url = extractRequiredString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||
URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||
String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||
|
||||
String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||
String projectId = extractRequiredString(map, PROJECT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException);
|
||||
|
||||
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
|
||||
map,
|
||||
DEFAULT_RATE_LIMIT_SETTINGS,
|
||||
validationException,
|
||||
IbmWatsonxService.NAME,
|
||||
context
|
||||
);
|
||||
|
||||
if (validationException.validationErrors().isEmpty() == false) {
|
||||
throw validationException;
|
||||
}
|
||||
|
||||
return new IbmWatsonxChatCompletionServiceSettings(uri, apiVersion, modelId, projectId, rateLimitSettings);
|
||||
}
|
||||
|
||||
private final URI uri;
|
||||
|
||||
private final String apiVersion;
|
||||
|
||||
private final String modelId;
|
||||
|
||||
private final String projectId;
|
||||
|
||||
private final RateLimitSettings rateLimitSettings;
|
||||
|
||||
public IbmWatsonxChatCompletionServiceSettings(
|
||||
URI uri,
|
||||
String apiVersion,
|
||||
String modelId,
|
||||
String projectId,
|
||||
@Nullable RateLimitSettings rateLimitSettings
|
||||
) {
|
||||
this.uri = uri;
|
||||
this.apiVersion = apiVersion;
|
||||
this.projectId = projectId;
|
||||
this.modelId = modelId;
|
||||
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
|
||||
}
|
||||
|
||||
public IbmWatsonxChatCompletionServiceSettings(StreamInput in) throws IOException {
|
||||
this.uri = createUri(in.readString());
|
||||
this.apiVersion = in.readString();
|
||||
this.modelId = in.readString();
|
||||
this.projectId = in.readString();
|
||||
this.rateLimitSettings = new RateLimitSettings(in);
|
||||
|
||||
}
|
||||
|
||||
public URI uri() {
|
||||
return uri;
|
||||
}
|
||||
|
||||
public String apiVersion() {
|
||||
return apiVersion;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String modelId() {
|
||||
return modelId;
|
||||
}
|
||||
|
||||
public String projectId() {
|
||||
return projectId;
|
||||
}
|
||||
|
||||
@Override
|
||||
public RateLimitSettings rateLimitSettings() {
|
||||
return rateLimitSettings;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getWriteableName() {
|
||||
return NAME;
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
|
||||
toXContentFragmentOfExposedFields(builder, params);
|
||||
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.field(URL, uri.toString());
|
||||
|
||||
builder.field(API_VERSION, apiVersion);
|
||||
|
||||
builder.field(MODEL_ID, modelId);
|
||||
|
||||
builder.field(PROJECT_ID, projectId);
|
||||
|
||||
rateLimitSettings.toXContent(builder, params);
|
||||
|
||||
return builder;
|
||||
}
|
||||
|
||||
@Override
|
||||
public TransportVersion getMinimalSupportedVersion() {
|
||||
return TransportVersions.ML_INFERENCE_IBM_WATSONX_COMPLETION_ADDED;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void writeTo(StreamOutput out) throws IOException {
|
||||
out.writeString(uri.toString());
|
||||
out.writeString(apiVersion);
|
||||
|
||||
out.writeString(modelId);
|
||||
out.writeString(projectId);
|
||||
|
||||
rateLimitSettings.writeTo(out);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean equals(Object object) {
|
||||
if (this == object) return true;
|
||||
if (object == null || getClass() != object.getClass()) return false;
|
||||
IbmWatsonxChatCompletionServiceSettings that = (IbmWatsonxChatCompletionServiceSettings) object;
|
||||
return Objects.equals(uri, that.uri)
|
||||
&& Objects.equals(apiVersion, that.apiVersion)
|
||||
&& Objects.equals(modelId, that.modelId)
|
||||
&& Objects.equals(projectId, that.projectId)
|
||||
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return Objects.hash(uri, apiVersion, modelId, projectId, rateLimitSettings);
|
||||
}
|
||||
}
|
|
@ -52,7 +52,7 @@ public class IbmWatsonxEmbeddingsServiceSettings extends FilteredXContentObject
|
|||
/**
|
||||
* Rate limits are defined at
|
||||
* <a href="https://www.ibm.com/docs/en/watsonx/saas?topic=learning-watson-machine-plans">Watson Machine Learning plans</a>.
|
||||
* For Lite plan, you've 120 requests per minute.
|
||||
* For the Lite plan, the limit is 120 requests per minute.
|
||||
*/
|
||||
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);
|
||||
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.ibmwatsonx.request;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.apache.http.entity.ByteArrayEntity;
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
|
||||
import org.elasticsearch.xpack.inference.external.request.Request;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
|
||||
|
||||
import java.net.URI;
|
||||
import java.nio.charset.StandardCharsets;
|
||||
import java.util.Objects;
|
||||
|
||||
public class IbmWatsonxChatCompletionRequest implements IbmWatsonxRequest {
|
||||
private final IbmWatsonxChatCompletionModel model;
|
||||
private final UnifiedChatInput chatInput;
|
||||
|
||||
public IbmWatsonxChatCompletionRequest(UnifiedChatInput chatInput, IbmWatsonxChatCompletionModel model) {
|
||||
this.chatInput = Objects.requireNonNull(chatInput);
|
||||
this.model = Objects.requireNonNull(model);
|
||||
}
|
||||
|
||||
@Override
|
||||
public HttpRequest createHttpRequest() {
|
||||
HttpPost httpPost = new HttpPost(model.uri());
|
||||
|
||||
ByteArrayEntity byteEntity = new ByteArrayEntity(
|
||||
Strings.toString(new IbmWatsonxChatCompletionRequestEntity(chatInput, model)).getBytes(StandardCharsets.UTF_8)
|
||||
);
|
||||
httpPost.setEntity(byteEntity);
|
||||
|
||||
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());
|
||||
|
||||
decorateWithAuth(httpPost);
|
||||
|
||||
return new HttpRequest(httpPost, getInferenceEntityId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public URI getURI() {
|
||||
return model.uri();
|
||||
}
|
||||
|
||||
public void decorateWithAuth(HttpPost httpPost) {
|
||||
IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Request truncate() {
|
||||
// No truncation for IBM watsonx chat completions
|
||||
return this;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean[] getTruncationInfo() {
|
||||
// No truncation for IBM watsonx chat completions
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String getInferenceEntityId() {
|
||||
return model.getInferenceEntityId();
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isStreaming() {
|
||||
return chatInput.stream();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.ibmwatsonx.request;
|
||||
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.xcontent.ToXContentObject;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.external.unified.UnifiedChatCompletionRequestEntity;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* IbmWatsonxChatCompletionRequestEntity is responsible for creating the request entity for Watsonx chat completion.
|
||||
* It implements ToXContentObject to allow serialization to XContent format.
|
||||
*/
|
||||
public class IbmWatsonxChatCompletionRequestEntity implements ToXContentObject {
|
||||
|
||||
private final IbmWatsonxChatCompletionModel model;
|
||||
private final UnifiedChatCompletionRequestEntity unifiedRequestEntity;
|
||||
|
||||
private static final String PROJECT_ID_FIELD = "project_id";
|
||||
|
||||
public IbmWatsonxChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, IbmWatsonxChatCompletionModel model) {
|
||||
this.unifiedRequestEntity = new UnifiedChatCompletionRequestEntity(unifiedChatInput);
|
||||
this.model = Objects.requireNonNull(model);
|
||||
}
|
||||
|
||||
@Override
|
||||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
|
||||
builder.startObject();
|
||||
builder.field(PROJECT_ID_FIELD, model.getServiceSettings().projectId());
|
||||
unifiedRequestEntity.toXContent(
|
||||
builder,
|
||||
UnifiedCompletionRequest.withMaxTokensAndSkipStreamOptionsField(model.getServiceSettings().modelId(), params)
|
||||
);
|
||||
builder.endObject();
|
||||
return builder;
|
||||
}
|
||||
}
|
|
@ -14,6 +14,7 @@ public class IbmWatsonxUtils {
|
|||
public static final String TEXT = "text";
|
||||
public static final String EMBEDDINGS = "embeddings";
|
||||
public static final String RERANKS = "reranks";
|
||||
public static final String COMPLETIONS = "chat";
|
||||
|
||||
private IbmWatsonxUtils() {}
|
||||
|
||||
|
|
|
@ -100,8 +100,8 @@ public class IbmWatsonxRerankModel extends IbmWatsonxModel {
|
|||
|
||||
/**
|
||||
* Accepts a visitor to create an executable action. The returned action will not return documents in the response.
|
||||
* @param visitor _
|
||||
* @param taskSettings _
|
||||
* @param visitor Interface for creating {@link ExecutableAction} instances for IBM watsonx models.
|
||||
* @param taskSettings Settings in the request to override the model's defaults
|
||||
* @return the rerank action
|
||||
*/
|
||||
@Override
|
||||
|
|
|
@ -41,7 +41,7 @@ public class IbmWatsonxRerankServiceSettings extends FilteredXContentObject impl
|
|||
/**
|
||||
* Rate limits are defined at
|
||||
* <a href="https://www.ibm.com/docs/en/watsonx/saas?topic=learning-watson-machine-plans">Watson Machine Learning plans</a>.
|
||||
* For Lite plan, you've 120 requests per minute.
|
||||
* For the Lite plan, the limit is 120 requests per minute.
|
||||
*/
|
||||
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(120);
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ import static org.elasticsearch.xpack.inference.external.response.XContentUtils.
|
|||
|
||||
public class IbmWatsonxEmbeddingsResponseEntity {
|
||||
|
||||
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in IBM Watsonx embeddings response";
|
||||
private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in IBM watsonx embeddings response";
|
||||
|
||||
public static TextEmbeddingFloatResults fromResponse(Request request, HttpResult response) throws IOException {
|
||||
var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE);
|
||||
|
|
|
@ -32,7 +32,7 @@ public class IbmWatsonxRankedResponseEntity {
|
|||
private static final Logger logger = LogManager.getLogger(IbmWatsonxRankedResponseEntity.class);
|
||||
|
||||
/**
|
||||
* Parses the Ibm Watsonx ranked response.
|
||||
* Parses the IBM watsonx ranked response.
|
||||
*
|
||||
* For a request like:
|
||||
* "model": "rerank-english-v2.0",
|
||||
|
@ -71,7 +71,7 @@ public class IbmWatsonxRankedResponseEntity {
|
|||
* ],
|
||||
* }
|
||||
*
|
||||
* @param response the http response from ibm watsonx
|
||||
* @param response the http response from IBM watsonx
|
||||
* @return the parsed response
|
||||
* @throws IOException if there is an error parsing the response
|
||||
*/
|
||||
|
|
|
@ -84,8 +84,8 @@ public class JinaAIRerankModel extends JinaAIModel {
|
|||
|
||||
/**
|
||||
* Accepts a visitor to create an executable action. The returned action will not return documents in the response.
|
||||
* @param visitor _
|
||||
* @param taskSettings _
|
||||
* @param visitor Interface for creating {@link ExecutableAction} instances for Jina AI models.
|
||||
* @param taskSettings Settings in the request to override the model's defaults
|
||||
* @return the rerank action
|
||||
*/
|
||||
@Override
|
||||
|
|
|
@ -37,8 +37,8 @@ import static org.elasticsearch.core.Strings.format;
|
|||
public class MistralActionCreator implements MistralActionVisitor {
|
||||
|
||||
public static final String COMPLETION_ERROR_PREFIX = "Mistral completions";
|
||||
static final String USER_ROLE = "user";
|
||||
static final ResponseHandler COMPLETION_HANDLER = new MistralCompletionResponseHandler(
|
||||
public static final String USER_ROLE = "user";
|
||||
public static final ResponseHandler COMPLETION_HANDLER = new MistralCompletionResponseHandler(
|
||||
"mistral completions",
|
||||
OpenAiChatCompletionResponseEntity::fromResponse
|
||||
);
|
||||
|
|
|
@ -109,8 +109,8 @@ public class VoyageAIRerankModel extends VoyageAIModel {
|
|||
|
||||
/**
|
||||
* Accepts a visitor to create an executable action. The returned action will not return documents in the response.
|
||||
* @param visitor _
|
||||
* @param taskSettings _
|
||||
* @param visitor Interface for creating {@link ExecutableAction} instances for Voyage AI models.
|
||||
* @param taskSettings Settings in the request to override the model's defaults
|
||||
* @return the rerank action
|
||||
*/
|
||||
@Override
|
||||
|
|
|
@ -0,0 +1,154 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services;
|
||||
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.test.http.MockResponse;
|
||||
import org.elasticsearch.test.http.MockWebServer;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
|
||||
import static org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests.createSender;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public abstract class ChatCompletionActionTests extends ESTestCase {
|
||||
protected static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
||||
protected final MockWebServer webServer = new MockWebServer();
|
||||
protected HttpClientManager clientManager;
|
||||
protected ThreadPool threadPool;
|
||||
|
||||
protected abstract ExecutableAction createAction(String url, Sender sender) throws URISyntaxException;
|
||||
|
||||
protected abstract String getOneInputError();
|
||||
|
||||
protected abstract String getFailedToSendError();
|
||||
|
||||
@Before
|
||||
public void init() throws Exception {
|
||||
webServer.start();
|
||||
threadPool = createThreadPool(inferenceUtilityPool());
|
||||
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
|
||||
}
|
||||
|
||||
@After
|
||||
public void shutdown() throws IOException {
|
||||
clientManager.close();
|
||||
terminate(threadPool);
|
||||
webServer.close();
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsElasticsearchException() throws URISyntaxException {
|
||||
var sender = mock(Sender.class);
|
||||
doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is("failed"));
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() throws URISyntaxException {
|
||||
var sender = mock(Sender.class);
|
||||
|
||||
doAnswer(invocation -> {
|
||||
ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
|
||||
listener.onFailure(new IllegalStateException("failed"));
|
||||
|
||||
return Void.TYPE;
|
||||
}).when(sender).send(any(), any(), any(), any());
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is(getFailedToSendError()));
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException, URISyntaxException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson()));
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is(getOneInputError()));
|
||||
assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
|
||||
}
|
||||
}
|
||||
|
||||
protected String getResponseJson() {
|
||||
return """
|
||||
{
|
||||
"id": "9d80f26810ac4e9582f927fcf0512ec7",
|
||||
"object": "chat.completion",
|
||||
"created": 1748596419,
|
||||
"model": "modelId",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"tool_calls": null,
|
||||
"content": "result content"
|
||||
},
|
||||
"finish_reason": "length",
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"total_tokens": 11,
|
||||
"completion_tokens": 1
|
||||
}
|
||||
}
|
||||
""";
|
||||
}
|
||||
}
|
|
@ -918,8 +918,8 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
|||
String content = XContentHelper.stripWhitespace("""
|
||||
{
|
||||
"service": "watsonxai",
|
||||
"name": "IBM Watsonx",
|
||||
"task_types": ["text_embedding"],
|
||||
"name": "IBM watsonx",
|
||||
"task_types": ["text_embedding", "completion", "chat_completion"],
|
||||
"configurations": {
|
||||
"project_id": {
|
||||
"description": "",
|
||||
|
@ -928,7 +928,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
|||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding"]
|
||||
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
|
||||
},
|
||||
"model_id": {
|
||||
"description": "The name of the model to use for the inference task.",
|
||||
|
@ -937,16 +937,16 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
|||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding"]
|
||||
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
|
||||
},
|
||||
"api_version": {
|
||||
"description": "The IBM Watsonx API version ID to use.",
|
||||
"description": "The IBM watsonx API version ID to use.",
|
||||
"label": "API Version",
|
||||
"required": true,
|
||||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding"]
|
||||
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
|
||||
},
|
||||
"max_input_tokens": {
|
||||
"description": "Allows you to specify the maximum number of tokens per input.",
|
||||
|
@ -964,7 +964,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
|
|||
"sensitive": false,
|
||||
"updatable": false,
|
||||
"type": "str",
|
||||
"supported_task_types": ["text_embedding"]
|
||||
"supported_task_types": ["text_embedding", "completion", "chat_completion"]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.ibmwatsonx.action;
|
||||
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.ChatCompletionActionTests;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.request.IbmWatsonxChatCompletionRequest;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
|
||||
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator.COMPLETION_HANDLER;
|
||||
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.action.IbmWatsonxActionCreator.USER_ROLE;
|
||||
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModelTests.createModel;
|
||||
|
||||
public class IbmWatsonxChatCompletionActionTests extends ChatCompletionActionTests {
|
||||
public static final URI TEST_URI = URI.create("abc.com");
|
||||
|
||||
protected ExecutableAction createAction(String url, Sender sender) throws URISyntaxException {
|
||||
var model = createModel(TEST_URI, randomAlphaOfLength(8), randomAlphaOfLength(8), randomAlphaOfLength(8), randomAlphaOfLength(8));
|
||||
var manager = new GenericRequestManager<>(
|
||||
threadPool,
|
||||
model,
|
||||
COMPLETION_HANDLER,
|
||||
inputs -> new IbmWatsonxChatCompletionRequest(new UnifiedChatInput(inputs, USER_ROLE), model),
|
||||
ChatCompletionInput.class
|
||||
);
|
||||
var errorMessage = constructFailedToSendRequestMessage("watsonx chat completions");
|
||||
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "watsonx chat completions");
|
||||
}
|
||||
|
||||
protected String getFailedToSendError() {
|
||||
return "Failed to send watsonx chat completions request. Cause: failed";
|
||||
}
|
||||
|
||||
protected String getOneInputError() {
|
||||
return "watsonx chat completions only accepts 1 input";
|
||||
}
|
||||
}
|
|
@ -180,7 +180,7 @@ public class IbmWatsonxEmbeddingsActionTests extends ESTestCase {
|
|||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is("Failed to send IBM Watsonx embeddings request. Cause: failed"));
|
||||
assertThat(thrownException.getMessage(), is("Failed to send IBM watsonx embeddings request. Cause: failed"));
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsException() {
|
||||
|
@ -204,7 +204,7 @@ public class IbmWatsonxEmbeddingsActionTests extends ESTestCase {
|
|||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is("Failed to send IBM Watsonx embeddings request. Cause: failed"));
|
||||
assertThat(thrownException.getMessage(), is("Failed to send IBM watsonx embeddings request. Cause: failed"));
|
||||
}
|
||||
|
||||
private ExecutableAction createAction(
|
||||
|
@ -218,7 +218,7 @@ public class IbmWatsonxEmbeddingsActionTests extends ESTestCase {
|
|||
) {
|
||||
var model = createModel(modelName, projectId, uri, apiVersion, apiKey, url);
|
||||
var requestManager = new IbmWatsonxEmbeddingsRequestManagerWithoutAuth(model, TruncatorTests.createTruncator(), threadPool);
|
||||
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM Watsonx embeddings");
|
||||
var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage("IBM watsonx embeddings");
|
||||
return new SenderExecutableAction(sender, requestManager, failedToSendRequestErrorMessage);
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion;
|
||||
|
||||
import org.elasticsearch.common.settings.SecureString;
|
||||
import org.elasticsearch.inference.TaskType;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.List;
|
||||
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class IbmWatsonxChatCompletionModelTests extends ESTestCase {
|
||||
// URI passed to createModel by the tests in this class.
private static final URI TEST_URI = URI.create("abc.com");
|
||||
|
||||
public static IbmWatsonxChatCompletionModel createModel(URI uri, String apiVersion, String modelId, String projectId, String apiKey)
|
||||
throws URISyntaxException {
|
||||
return new IbmWatsonxChatCompletionModel(
|
||||
"id",
|
||||
TaskType.COMPLETION,
|
||||
"service",
|
||||
new IbmWatsonxChatCompletionServiceSettings(uri, apiVersion, modelId, projectId, null),
|
||||
new DefaultSecretSettings(new SecureString(apiKey.toCharArray()))
|
||||
);
|
||||
}
|
||||
|
||||
public void testOverrideWith_UnifiedCompletionRequest_OverridesExistingModelId() throws URISyntaxException {
|
||||
var model = createModel(TEST_URI, "apiVersion", "modelId", "projectId", "apiKey");
|
||||
var request = new UnifiedCompletionRequest(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||
"different_model",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
|
||||
var overriddenModel = IbmWatsonxChatCompletionModel.of(model, request);
|
||||
|
||||
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
|
||||
}
|
||||
|
||||
public void testOverrideWith_UnifiedCompletionRequest_OverridesNullModelId() throws URISyntaxException {
|
||||
var model = createModel(TEST_URI, "apiVersion", null, "projectId", "apiKey");
|
||||
var request = new UnifiedCompletionRequest(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||
"different_model",
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
|
||||
var overriddenModel = IbmWatsonxChatCompletionModel.of(model, request);
|
||||
|
||||
assertThat(overriddenModel.getServiceSettings().modelId(), is("different_model"));
|
||||
}
|
||||
|
||||
public void testOverrideWith_UnifiedCompletionRequest_KeepsNullIfNoModelIdProvided() throws URISyntaxException {
|
||||
var model = createModel(TEST_URI, "apiVersion", null, "projectId", "apiKey");
|
||||
var request = new UnifiedCompletionRequest(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
|
||||
var overriddenModel = IbmWatsonxChatCompletionModel.of(model, request);
|
||||
|
||||
assertNull(overriddenModel.getServiceSettings().modelId());
|
||||
}
|
||||
|
||||
public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() throws URISyntaxException {
|
||||
var model = createModel(TEST_URI, "apiVersion", "modelId", "projectId", "apiKey");
|
||||
var request = new UnifiedCompletionRequest(
|
||||
List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null)),
|
||||
null, // not overriding model
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null,
|
||||
null
|
||||
);
|
||||
|
||||
var overriddenModel = IbmWatsonxChatCompletionModel.of(model, request);
|
||||
|
||||
assertThat(overriddenModel.getServiceSettings().modelId(), is("modelId"));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,173 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.ibmwatsonx.completion;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.ValidationException;
|
||||
import org.elasticsearch.common.io.stream.Writeable;
|
||||
import org.elasticsearch.test.AbstractWireSerializingTestCase;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.XContentFactory;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
|
||||
import org.elasticsearch.xpack.inference.services.ServiceFields;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields;
|
||||
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString;
|
||||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class IbmWatsonxChatCompletionServiceSettingsTests extends AbstractWireSerializingTestCase<IbmWatsonxChatCompletionServiceSettings> {
|
||||
private static final URI TEST_URI = URI.create("abc.com");
|
||||
|
||||
private static IbmWatsonxChatCompletionServiceSettings createRandom() {
|
||||
return new IbmWatsonxChatCompletionServiceSettings(
|
||||
TEST_URI,
|
||||
randomAlphaOfLength(8),
|
||||
randomAlphaOfLength(8),
|
||||
randomAlphaOfLength(8),
|
||||
randomFrom(RateLimitSettingsTests.createRandom(), null)
|
||||
);
|
||||
}
|
||||
|
||||
private IbmWatsonxChatCompletionServiceSettings getServiceSettings(Map<String, String> map) {
|
||||
return IbmWatsonxChatCompletionServiceSettings.fromMap(new HashMap<>(map), ConfigurationParseContext.PERSISTENT);
|
||||
}
|
||||
|
||||
public void testFromMap_WithAllParameters_CreatesSettingsCorrectly() {
|
||||
var model = randomAlphaOfLength(8);
|
||||
var projectId = randomAlphaOfLength(8);
|
||||
var apiVersion = randomAlphaOfLength(8);
|
||||
|
||||
var serviceSettings = getServiceSettings(
|
||||
Map.of(
|
||||
ServiceFields.URL,
|
||||
TEST_URI.toString(),
|
||||
IbmWatsonxServiceFields.API_VERSION,
|
||||
apiVersion,
|
||||
ServiceFields.MODEL_ID,
|
||||
model,
|
||||
IbmWatsonxServiceFields.PROJECT_ID,
|
||||
projectId
|
||||
)
|
||||
);
|
||||
assertThat(serviceSettings, is(new IbmWatsonxChatCompletionServiceSettings(TEST_URI, apiVersion, model, projectId, null)));
|
||||
}
|
||||
|
||||
public void testFromMap_Fails_WithoutRequiredParam_Url() {
|
||||
var ex = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> getServiceSettings(
|
||||
Map.of(
|
||||
IbmWatsonxServiceFields.API_VERSION,
|
||||
randomAlphaOfLength(8),
|
||||
ServiceFields.MODEL_ID,
|
||||
randomAlphaOfLength(8),
|
||||
IbmWatsonxServiceFields.PROJECT_ID,
|
||||
randomAlphaOfLength(8)
|
||||
)
|
||||
)
|
||||
);
|
||||
assertThat(ex.getMessage(), equalTo(generateErrorMessage("url")));
|
||||
}
|
||||
|
||||
public void testFromMap_Fails_WithoutRequiredParam_ApiVersion() {
|
||||
var ex = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> getServiceSettings(
|
||||
Map.of(
|
||||
ServiceFields.URL,
|
||||
TEST_URI.toString(),
|
||||
ServiceFields.MODEL_ID,
|
||||
randomAlphaOfLength(8),
|
||||
IbmWatsonxServiceFields.PROJECT_ID,
|
||||
randomAlphaOfLength(8)
|
||||
)
|
||||
)
|
||||
);
|
||||
assertThat(ex.getMessage(), equalTo(generateErrorMessage("api_version")));
|
||||
}
|
||||
|
||||
public void testFromMap_Fails_WithoutRequiredParam_ModelId() {
|
||||
var ex = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> getServiceSettings(
|
||||
Map.of(
|
||||
ServiceFields.URL,
|
||||
TEST_URI.toString(),
|
||||
IbmWatsonxServiceFields.API_VERSION,
|
||||
randomAlphaOfLength(8),
|
||||
IbmWatsonxServiceFields.PROJECT_ID,
|
||||
randomAlphaOfLength(8)
|
||||
)
|
||||
)
|
||||
);
|
||||
assertThat(ex.getMessage(), equalTo(generateErrorMessage("model_id")));
|
||||
}
|
||||
|
||||
public void testFromMap_Fails_WithoutRequiredParam_ProjectId() {
|
||||
var ex = expectThrows(
|
||||
ValidationException.class,
|
||||
() -> getServiceSettings(
|
||||
Map.of(
|
||||
ServiceFields.URL,
|
||||
TEST_URI.toString(),
|
||||
IbmWatsonxServiceFields.API_VERSION,
|
||||
randomAlphaOfLength(8),
|
||||
ServiceFields.MODEL_ID,
|
||||
randomAlphaOfLength(8)
|
||||
)
|
||||
)
|
||||
);
|
||||
assertThat(ex.getMessage(), equalTo(generateErrorMessage("project_id")));
|
||||
}
|
||||
|
||||
private String generateErrorMessage(String field) {
|
||||
return "Validation Failed: 1: [service_settings] does not contain the required setting [" + field + "];";
|
||||
}
|
||||
|
||||
public void testToXContent_WritesAllValues() throws IOException {
|
||||
var entity = new IbmWatsonxChatCompletionServiceSettings(TEST_URI, "2024-05-02", "model", "project_id", null);
|
||||
|
||||
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
|
||||
entity.toXContent(builder, null);
|
||||
String xContentResult = Strings.toString(builder);
|
||||
|
||||
assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString("""
|
||||
{
|
||||
"url":"abc.com",
|
||||
"api_version":"2024-05-02",
|
||||
"model_id":"model",
|
||||
"project_id":"project_id",
|
||||
"rate_limit": {
|
||||
"requests_per_minute":120
|
||||
}
|
||||
}"""));
|
||||
}
|
||||
|
||||
@Override
|
||||
protected Writeable.Reader<IbmWatsonxChatCompletionServiceSettings> instanceReader() {
|
||||
return IbmWatsonxChatCompletionServiceSettings::new;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected IbmWatsonxChatCompletionServiceSettings createTestInstance() {
|
||||
return createRandom();
|
||||
}
|
||||
|
||||
@Override
|
||||
protected IbmWatsonxChatCompletionServiceSettings mutateInstance(IbmWatsonxChatCompletionServiceSettings instance) throws IOException {
|
||||
return randomValueOtherThan(instance, IbmWatsonxChatCompletionServiceSettingsTests::createRandom);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,66 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.ibmwatsonx.request;
|
||||
|
||||
import org.elasticsearch.common.Strings;
|
||||
import org.elasticsearch.common.xcontent.XContentHelper;
|
||||
import org.elasticsearch.inference.UnifiedCompletionRequest;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xcontent.ToXContent;
|
||||
import org.elasticsearch.xcontent.XContentBuilder;
|
||||
import org.elasticsearch.xcontent.json.JsonXContent;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.ArrayList;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModelTests.createModel;
|
||||
|
||||
public class IbmWatsonxChatCompletionRequestEntityTests extends ESTestCase {
|
||||
|
||||
private static final String ROLE = "user";
|
||||
|
||||
public void testModelUserFieldsSerialization() throws IOException, URISyntaxException {
|
||||
UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message(
|
||||
new UnifiedCompletionRequest.ContentString("test content"),
|
||||
ROLE,
|
||||
null,
|
||||
null
|
||||
);
|
||||
var messageList = new ArrayList<UnifiedCompletionRequest.Message>();
|
||||
messageList.add(message);
|
||||
|
||||
var unifiedRequest = UnifiedCompletionRequest.of(messageList);
|
||||
|
||||
UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true);
|
||||
IbmWatsonxChatCompletionModel model = createModel(new URI("abc.com"), "apiVersion", "modelId", "projectId", "apiKey");
|
||||
|
||||
IbmWatsonxChatCompletionRequestEntity entity = new IbmWatsonxChatCompletionRequestEntity(unifiedChatInput, model);
|
||||
|
||||
XContentBuilder builder = JsonXContent.contentBuilder();
|
||||
entity.toXContent(builder, ToXContent.EMPTY_PARAMS);
|
||||
String expectedJson = """
|
||||
{
|
||||
"project_id": "projectId",
|
||||
"messages": [
|
||||
{
|
||||
"content": "test content",
|
||||
"role": "user"
|
||||
}
|
||||
],
|
||||
"model": "modelId",
|
||||
"n": 1,
|
||||
"stream": true
|
||||
}
|
||||
""";
|
||||
assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder));
|
||||
}
|
||||
}
|
|
@ -0,0 +1,106 @@
|
|||
/*
|
||||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
|
||||
* or more contributor license agreements. Licensed under the Elastic License
|
||||
* 2.0; you may not use this file except in compliance with the Elastic License
|
||||
* 2.0.
|
||||
*/
|
||||
|
||||
package org.elasticsearch.xpack.inference.services.ibmwatsonx.request;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.apache.http.client.methods.HttpPost;
|
||||
import org.elasticsearch.core.Nullable;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModel;
|
||||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.completion.IbmWatsonxChatCompletionModelTests;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
import static org.hamcrest.Matchers.aMapWithSize;
|
||||
import static org.hamcrest.Matchers.instanceOf;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
|
||||
public class IbmWatsonxChatCompletionRequestTests extends ESTestCase {
|
||||
private static final String AUTH_HEADER_VALUE = "foo";
|
||||
private static final String API_COMPLETIONS_PATH = "https://abc.com/ml/v1/text/chat?version=apiVersion";
|
||||
|
||||
public void testCreateRequest_WithStreaming() throws IOException, URISyntaxException {
|
||||
assertCreateRequestWithStreaming(true);
|
||||
}
|
||||
|
||||
public void testCreateRequest_WithNoStreaming() throws IOException, URISyntaxException {
|
||||
assertCreateRequestWithStreaming(false);
|
||||
}
|
||||
|
||||
public void testTruncate_DoesNotReduceInputTextSize() throws IOException, URISyntaxException {
|
||||
String input = randomAlphaOfLength(5);
|
||||
String model = randomAlphaOfLength(5);
|
||||
|
||||
var request = createRequest(randomAlphaOfLength(5), input, model, true);
|
||||
var truncatedRequest = request.truncate();
|
||||
assertThat(request.getURI().toString(), is(API_COMPLETIONS_PATH));
|
||||
|
||||
var httpRequest = truncatedRequest.createHttpRequest();
|
||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
assertThat(requestMap, aMapWithSize(5));
|
||||
|
||||
assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", input))));
|
||||
assertThat(requestMap.get("model"), is(model));
|
||||
assertThat(requestMap.get("n"), is(1));
|
||||
assertTrue((Boolean) requestMap.get("stream"));
|
||||
assertNull(requestMap.get("stream_options"));
|
||||
}
|
||||
|
||||
public void testTruncationInfo_ReturnsNull() throws URISyntaxException {
|
||||
var request = createRequest(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5), true);
|
||||
assertNull(request.getTruncationInfo());
|
||||
}
|
||||
|
||||
public static IbmWatsonxChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model)
|
||||
throws URISyntaxException {
|
||||
return createRequest(apiKey, input, model, false);
|
||||
}
|
||||
|
||||
public static IbmWatsonxChatCompletionRequest createRequest(String apiKey, String input, @Nullable String model, boolean stream)
|
||||
throws URISyntaxException {
|
||||
var chatCompletionModel = IbmWatsonxChatCompletionModelTests.createModel(
|
||||
new URI("abc.com"),
|
||||
"apiVersion",
|
||||
model,
|
||||
randomAlphaOfLength(5),
|
||||
apiKey
|
||||
);
|
||||
return new IbmWatsonxChatCompletionWithoutAuthRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel);
|
||||
}
|
||||
|
||||
private static class IbmWatsonxChatCompletionWithoutAuthRequest extends IbmWatsonxChatCompletionRequest {
|
||||
IbmWatsonxChatCompletionWithoutAuthRequest(UnifiedChatInput input, IbmWatsonxChatCompletionModel model) {
|
||||
super(input, model);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void decorateWithAuth(HttpPost httpPost) {
|
||||
httpPost.setHeader(HttpHeaders.AUTHORIZATION, AUTH_HEADER_VALUE);
|
||||
}
|
||||
}
|
||||
|
||||
private void assertCreateRequestWithStreaming(boolean isStreaming) throws URISyntaxException, IOException {
|
||||
var request = createRequest(randomAlphaOfLength(5), randomAlphaOfLength(5), randomAlphaOfLength(5), isStreaming);
|
||||
var httpRequest = request.createHttpRequest();
|
||||
|
||||
assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class));
|
||||
var httpPost = (HttpPost) httpRequest.httpRequestBase();
|
||||
|
||||
var requestMap = entityAsMap(httpPost.getEntity().getContent());
|
||||
assertThat(requestMap.get("stream"), is(isStreaming));
|
||||
}
|
||||
}
|
|
@ -112,6 +112,6 @@ public class IbmWatsonxEmbeddingsResponseEntityTests extends ESTestCase {
|
|||
)
|
||||
);
|
||||
|
||||
assertThat(thrownException.getMessage(), is("Failed to find required field [results] in IBM Watsonx embeddings response"));
|
||||
assertThat(thrownException.getMessage(), is("Failed to find required field [results] in IBM watsonx embeddings response"));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,42 +8,28 @@
|
|||
package org.elasticsearch.xpack.inference.services.mistral.action;
|
||||
|
||||
import org.apache.http.HttpHeaders;
|
||||
import org.elasticsearch.ElasticsearchException;
|
||||
import org.elasticsearch.ElasticsearchStatusException;
|
||||
import org.elasticsearch.action.ActionListener;
|
||||
import org.elasticsearch.action.support.PlainActionFuture;
|
||||
import org.elasticsearch.common.settings.Settings;
|
||||
import org.elasticsearch.core.TimeValue;
|
||||
import org.elasticsearch.inference.InferenceServiceResults;
|
||||
import org.elasticsearch.rest.RestStatus;
|
||||
import org.elasticsearch.test.ESTestCase;
|
||||
import org.elasticsearch.test.http.MockRequest;
|
||||
import org.elasticsearch.test.http.MockResponse;
|
||||
import org.elasticsearch.test.http.MockWebServer;
|
||||
import org.elasticsearch.threadpool.ThreadPool;
|
||||
import org.elasticsearch.xcontent.XContentType;
|
||||
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction;
|
||||
import org.elasticsearch.xpack.inference.external.http.HttpClientManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.GenericRequestManager;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
|
||||
import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput;
|
||||
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
|
||||
import org.elasticsearch.xpack.inference.services.ChatCompletionActionTests;
|
||||
import org.elasticsearch.xpack.inference.services.mistral.request.completion.MistralChatCompletionRequest;
|
||||
import org.junit.After;
|
||||
import org.junit.Before;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
|
||||
import static org.elasticsearch.xpack.core.inference.results.ChatCompletionResultsTests.buildExpectationCompletion;
|
||||
import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
|
||||
import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
|
||||
import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
|
||||
import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap;
|
||||
|
@ -56,64 +42,15 @@ import static org.elasticsearch.xpack.inference.services.mistral.completion.Mist
|
|||
import static org.hamcrest.Matchers.equalTo;
|
||||
import static org.hamcrest.Matchers.hasSize;
|
||||
import static org.hamcrest.Matchers.is;
|
||||
import static org.mockito.ArgumentMatchers.any;
|
||||
import static org.mockito.Mockito.doAnswer;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
|
||||
public class MistralChatCompletionActionTests extends ESTestCase {
|
||||
private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS);
|
||||
private final MockWebServer webServer = new MockWebServer();
|
||||
private ThreadPool threadPool;
|
||||
private HttpClientManager clientManager;
|
||||
|
||||
@Before
|
||||
public void init() throws Exception {
|
||||
webServer.start();
|
||||
threadPool = createThreadPool(inferenceUtilityPool());
|
||||
clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class));
|
||||
}
|
||||
|
||||
/**
 * Per-test teardown: closes the HTTP client manager, terminates the thread pool and closes the mock web server,
 * in that order.
 */
@After
|
||||
public void shutdown() throws IOException {
|
||||
clientManager.close();
|
||||
terminate(threadPool);
|
||||
webServer.close();
|
||||
}
|
||||
|
||||
public void testExecute_ReturnsSuccessfulResponse() throws IOException {
|
||||
public class MistralChatCompletionActionTests extends ChatCompletionActionTests {
|
||||
public void testExecute_ReturnsSuccessfulResponse() throws IOException, URISyntaxException {
|
||||
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"id": "9d80f26810ac4e9582f927fcf0512ec7",
|
||||
"object": "chat.completion",
|
||||
"created": 1748596419,
|
||||
"model": "mistral-small-latest",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"tool_calls": null,
|
||||
"content": "result content"
|
||||
},
|
||||
"finish_reason": "length",
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"total_tokens": 11,
|
||||
"completion_tokens": 1
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(getResponseJson()));
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
|
@ -140,87 +77,7 @@ public class MistralChatCompletionActionTests extends ESTestCase {
|
|||
}
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsElasticsearchException() {
|
||||
var sender = mock(Sender.class);
|
||||
doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any());
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is("failed"));
|
||||
}
|
||||
|
||||
public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() {
|
||||
var sender = mock(Sender.class);
|
||||
|
||||
doAnswer(invocation -> {
|
||||
ActionListener<InferenceServiceResults> listener = invocation.getArgument(3);
|
||||
listener.onFailure(new IllegalStateException("failed"));
|
||||
|
||||
return Void.TYPE;
|
||||
}).when(sender).send(any(), any(), any(), any());
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is("Failed to send mistral chat completions request. Cause: failed"));
|
||||
}
|
||||
|
||||
/**
 * Executes the action with two inputs ("abc" and "def") and expects an {@link ElasticsearchStatusException} with
 * status {@code BAD_REQUEST} and the message "mistral chat completions only accepts 1 input".
 * <p>
 * A successful chat completion response is enqueued on the mock web server before the action runs.
 */
public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOException {
|
||||
var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
|
||||
|
||||
try (var sender = createSender(senderFactory)) {
|
||||
sender.start();
|
||||
|
||||
String responseJson = """
|
||||
{
|
||||
"id": "9d80f26810ac4e9582f927fcf0512ec7",
|
||||
"object": "chat.completion",
|
||||
"created": 1748596419,
|
||||
"model": "mistral-small-latest",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"tool_calls": null,
|
||||
"content": "result content"
|
||||
},
|
||||
"finish_reason": "length",
|
||||
"logprobs": null
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 10,
|
||||
"total_tokens": 11,
|
||||
"completion_tokens": 1
|
||||
}
|
||||
}
|
||||
""";
|
||||
|
||||
webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson));
|
||||
|
||||
var action = createAction(getUrl(webServer), sender);
|
||||
|
||||
PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>();
|
||||
action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener);
|
||||
|
||||
var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT));
|
||||
|
||||
assertThat(thrownException.getMessage(), is("mistral chat completions only accepts 1 input"));
|
||||
assertThat(thrownException.status(), is(RestStatus.BAD_REQUEST));
|
||||
}
|
||||
}
|
||||
|
||||
private ExecutableAction createAction(String url, Sender sender) {
|
||||
protected ExecutableAction createAction(String url, Sender sender) {
|
||||
var model = createCompletionModel("secret", "model");
|
||||
model.setURI(url);
|
||||
var manager = new GenericRequestManager<>(
|
||||
|
@ -233,4 +90,12 @@ public class MistralChatCompletionActionTests extends ESTestCase {
|
|||
var errorMessage = constructFailedToSendRequestMessage("mistral chat completions");
|
||||
return new SingleInputSenderExecutableAction(sender, manager, errorMessage, "mistral chat completions");
|
||||
}
|
||||
|
||||
/**
 * Returns the error message expected when sending a Mistral chat completion request fails with cause "failed".
 * NOTE(review): presumably consumed by the shared tests in {@code ChatCompletionActionTests}; that class is not
 * visible here, so confirm.
 */
protected String getFailedToSendError() {
|
||||
return "Failed to send mistral chat completions request. Cause: failed";
|
||||
}
|
||||
|
||||
/**
 * Returns the error message expected when more than one input is passed to the Mistral chat completion action.
 * NOTE(review): presumably consumed by the shared tests in {@code ChatCompletionActionTests}; that class is not
 * visible here, so confirm.
 */
protected String getOneInputError() {
|
||||
return "mistral chat completions only accepts 1 input";
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue