Skip to content

Commit

Permalink
Integrate watsonx for re-ranking task (elastic#117176)
Browse files Browse the repository at this point in the history
* Integrate watsonx reranking to inference api

* Add api_version to the watsonx api call

* Fix the return_doc option

* Add top_n parameter to task_settings

* Add truncate_input_tokens parameter to task_settings

* Add test for IbmWatsonxRankedResponseEntity

* Add test for IbmWatsonxRerankRequestEntity

* Add test for IbmWatsonxRerankRequest

* [CI] Auto commit changes from spotless

* Add changelog

* Fix transport version

* Add test for IbmWatsonxService

* Remove canHandleStreamingResponses

* Add requireNonNull for modelId and projectId

* Remove maxInputToken method

* Convert all optionals to required

* [CI] Auto commit changes from spotless

* Set minimal_supported version to be ML_INFERENCE_IBM_WATSONX_RERANK_ADDED

* Remove extraction of unused fields from IbmWatsonxRerankServiceSettings

* Add space

* Add space

---------

Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co>
(cherry picked from commit 09b1c6d)
  • Loading branch information
saikatsarkar056 committed Jan 31, 2025
1 parent 2b7d91a commit d2b5f05
Show file tree
Hide file tree
Showing 21 changed files with 1,370 additions and 2 deletions.
5 changes: 5 additions & 0 deletions docs/changelog/117176.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 117176
summary: Integrate IBM watsonx to Inference API for re-ranking task
area: Experiences
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ static TransportVersion def(int id) {
public static final TransportVersion LINEAR_RETRIEVER_SUPPORT = def(8_837_00_0);
public static final TransportVersion TIMEOUT_GET_PARAM_FOR_RESOLVE_CLUSTER = def(8_838_00_0);
public static final TransportVersion INFERENCE_REQUEST_ADAPTIVE_RATE_LIMITING = def(8_839_00_0);
public static final TransportVersion ML_INFERENCE_IBM_WATSONX_RERANK_ADDED = def(8_840_00_0);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@
import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings;
import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankServiceSettings;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;
import org.elasticsearch.xpack.inference.services.jinaai.JinaAIServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsServiceSettings;
import org.elasticsearch.xpack.inference.services.jinaai.embeddings.JinaAIEmbeddingsTaskSettings;
Expand Down Expand Up @@ -364,6 +366,17 @@ private static void addIbmWatsonxNamedWritables(List<NamedWriteableRegistry.Entr
IbmWatsonxEmbeddingsServiceSettings::new
)
);

namedWriteables.add(
new NamedWriteableRegistry.Entry(
ServiceSettings.class,
IbmWatsonxRerankServiceSettings.NAME,
IbmWatsonxRerankServiceSettings::new
)
);
namedWriteables.add(
new NamedWriteableRegistry.Entry(TaskSettings.class, IbmWatsonxRerankTaskSettings.NAME, IbmWatsonxRerankTaskSettings::new)
);
}

private static void addGoogleVertexAiNamedWriteables(List<NamedWriteableRegistry.Entry> namedWriteables) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,18 @@
import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxEmbeddingsRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxRerankRequestManager;
import org.elasticsearch.xpack.inference.external.http.sender.Sender;
import org.elasticsearch.xpack.inference.services.ServiceComponents;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;

import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;

public class IbmWatsonxActionCreator implements IbmWatsonxActionVisitor {

private final Sender sender;
private final ServiceComponents serviceComponents;

Expand All @@ -41,6 +42,17 @@ public ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Obje
);
}

/**
 * Creates the executable action for an IBM watsonx rerank model.
 *
 * The model is first combined with the supplied task settings through
 * {@link IbmWatsonxRerankModel#of}; the resulting model drives both the request manager
 * and the URI reported in the "failed to send request" error message.
 *
 * @param model        the configured rerank model
 * @param taskSettings task settings supplied for this call, passed to {@link IbmWatsonxRerankModel#of}
 * @return an action that sends rerank requests through this creator's sender
 */
@Override
public ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings) {
    var rerankModel = IbmWatsonxRerankModel.of(model, taskSettings);
    var requestManager = IbmWatsonxRerankRequestManager.of(rerankModel, serviceComponents.threadPool());

    return new SenderExecutableAction(
        sender,
        requestManager,
        constructFailedToSendRequestMessage(rerankModel.getServiceSettings().uri(), "Ibm Watsonx rerank")
    );
}

protected IbmWatsonxEmbeddingsRequestManager getEmbeddingsRequestManager(
IbmWatsonxEmbeddingsModel model,
Truncator truncator,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@

import org.elasticsearch.xpack.inference.external.action.ExecutableAction;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;

import java.util.Map;

/**
 * Visitor that builds an {@link ExecutableAction} for an IBM watsonx model.
 * There is one {@code create} overload per supported model type (embeddings and rerank).
 */
public interface IbmWatsonxActionVisitor {
/**
 * Creates the action used to run inference against an embeddings model.
 *
 * @param model        the embeddings model to create an action for
 * @param taskSettings task settings supplied alongside the model
 * @return the executable action for the model
 */
ExecutableAction create(IbmWatsonxEmbeddingsModel model, Map<String, Object> taskSettings);

/**
 * Creates the action used to run inference against a rerank model.
 *
 * @param model        the rerank model to create an action for
 * @param taskSettings task settings supplied alongside the model
 * @return the executable action for the model
 */
ExecutableAction create(IbmWatsonxRerankModel model, Map<String, Object> taskSettings);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.http.sender;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler;
import org.elasticsearch.xpack.inference.external.ibmwatsonx.IbmWatsonxResponseHandler;
import org.elasticsearch.xpack.inference.external.request.ibmwatsonx.IbmWatsonxRerankRequest;
import org.elasticsearch.xpack.inference.external.response.ibmwatsonx.IbmWatsonxRankedResponseEntity;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;

import java.util.List;
import java.util.Objects;
import java.util.function.Supplier;

/**
 * Request manager for IBM watsonx rerank calls.
 *
 * For every execution it builds an {@link IbmWatsonxRerankRequest} from the query and
 * document chunks of the inference inputs and hands it, together with a shared response
 * handler, to the parent class for sending.
 */
public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager {
    private static final Logger logger = LogManager.getLogger(IbmWatsonxRerankRequestManager.class);

    /** Shared handler that parses successful responses with {@link IbmWatsonxRankedResponseEntity#fromResponse}. */
    private static final ResponseHandler HANDLER = new IbmWatsonxResponseHandler(
        "ibm watsonx rerank",
        (ignoredRequest, httpResult) -> IbmWatsonxRankedResponseEntity.fromResponse(httpResult)
    );

    private final IbmWatsonxRerankModel model;

    /**
     * @param model      the rerank model every request is built for
     * @param threadPool thread pool forwarded to the parent request manager
     */
    public IbmWatsonxRerankRequestManager(IbmWatsonxRerankModel model, ThreadPool threadPool) {
        super(threadPool, model);
        this.model = model;
    }

    /**
     * Null-checking factory.
     *
     * @throws NullPointerException if {@code model} or {@code threadPool} is null
     */
    public static IbmWatsonxRerankRequestManager of(IbmWatsonxRerankModel model, ThreadPool threadPool) {
        Objects.requireNonNull(model);
        Objects.requireNonNull(threadPool);
        return new IbmWatsonxRerankRequestManager(model, threadPool);
    }

    /**
     * Converts the inputs to query-and-documents form, builds the rerank request and submits it.
     */
    @Override
    public void execute(
        InferenceInputs inferenceInputs,
        RequestSender requestSender,
        Supplier<Boolean> hasRequestCompletedFunction,
        ActionListener<InferenceServiceResults> listener
    ) {
        var queryAndDocs = QueryAndDocsInputs.of(inferenceInputs);
        IbmWatsonxRerankRequest request = getRerankRequest(queryAndDocs.getQuery(), queryAndDocs.getChunks(), model);

        execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener));
    }

    /**
     * Builds the outgoing request; protected so subclasses can substitute their own request.
     */
    protected IbmWatsonxRerankRequest getRerankRequest(String query, List<String> chunks, IbmWatsonxRerankModel model) {
        return new IbmWatsonxRerankRequest(query, chunks, model);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import static org.elasticsearch.core.Strings.format;

public class IbmWatsonxResponseHandler extends BaseResponseHandler {

/**
 * Creates a response handler for IBM watsonx requests.
 *
 * @param requestType   label for the kind of request being handled, forwarded to the base handler
 * @param parseFunction parser forwarded to the base handler; presumably applied to successful responses
 */
public IbmWatsonxResponseHandler(String requestType, ResponseParser parseFunction) {
// The third argument fixes IbmWatsonxErrorResponseEntity::fromResponse for all watsonx handlers
// (presumably the error-body parser -- confirm against BaseResponseHandler).
super(requestType, parseFunction, IbmWatsonxErrorResponseEntity::fromResponse);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.ibmwatsonx;

import org.apache.http.HttpHeaders;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.ByteArrayEntity;
import org.elasticsearch.common.Strings;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.inference.external.request.HttpRequest;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;

import java.net.URI;
import java.net.URISyntaxException;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Objects;

/**
 * HTTP request for the IBM watsonx rerank endpoint.
 *
 * Holds the query, the input documents and the model, and turns them into an
 * {@link HttpPost} whose JSON body is produced by {@link IbmWatsonxRerankRequestEntity}.
 * The request is never truncated: {@link #truncate()} returns this instance and
 * {@link #getTruncationInfo()} returns {@code null}.
 */
public class IbmWatsonxRerankRequest implements IbmWatsonxRequest {

    private final String query;
    private final List<String> input;
    private final IbmWatsonxRerankTaskSettings taskSettings;
    private final IbmWatsonxRerankModel model;

    /**
     * @param query the query the documents are ranked against
     * @param input the documents to rank
     * @param model the rerank model supplying URI, settings and secrets
     * @throws NullPointerException if any argument is null
     */
    public IbmWatsonxRerankRequest(String query, List<String> input, IbmWatsonxRerankModel model) {
        Objects.requireNonNull(model);

        this.input = Objects.requireNonNull(input);
        this.query = Objects.requireNonNull(query);
        taskSettings = model.getTaskSettings();
        this.model = model;
    }

    /**
     * Builds the POST request: JSON body, JSON content type header and bearer-token auth.
     *
     * @return the HTTP request paired with the inference entity id
     * @throws IllegalArgumentException if the model URI cannot be re-parsed; the original
     *                                  {@link URISyntaxException} is attached as the cause
     */
    @Override
    public HttpRequest createHttpRequest() {
        URI uri;

        try {
            uri = new URI(model.uri().toString());
        } catch (URISyntaxException ex) {
            // Keep the cause so the offending input and parse position are not lost.
            throw new IllegalArgumentException("cannot parse URI pattern", ex);
        }

        HttpPost httpPost = new HttpPost(uri);

        ByteArrayEntity byteEntity = new ByteArrayEntity(
            Strings.toString(
                new IbmWatsonxRerankRequestEntity(
                    query,
                    input,
                    taskSettings,
                    model.getServiceSettings().modelId(),
                    model.getServiceSettings().projectId()
                )
            ).getBytes(StandardCharsets.UTF_8)
        );

        httpPost.setEntity(byteEntity);
        httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType());

        decorateWithAuth(httpPost);

        return new HttpRequest(httpPost, getInferenceEntityId());
    }

    /**
     * Adds bearer-token authentication to the request using the model's secret settings.
     *
     * @param httpPost the request to decorate
     */
    public void decorateWithAuth(HttpPost httpPost) {
        IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId());
    }

    @Override
    public String getInferenceEntityId() {
        return model.getInferenceEntityId();
    }

    @Override
    public URI getURI() {
        return model.uri();
    }

    /**
     * Rerank requests are not truncated, so the same instance is returned unchanged.
     */
    @Override
    public Request truncate() {
        return this;
    }

    public String getQuery() {
        return query;
    }

    public List<String> getInput() {
        return input;
    }

    public IbmWatsonxRerankModel getModel() {
        return model;
    }

    /**
     * @return always {@code null}; this request carries no truncation information
     */
    @Override
    public boolean[] getTruncationInfo() {
        return null;
    }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.inference.external.request.ibmwatsonx;

import org.elasticsearch.xcontent.ToXContentObject;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

/**
 * JSON body of an IBM watsonx rerank request.
 *
 * Serialized shape, in field order: {@code model_id}, {@code query}, {@code inputs}
 * (one {@code {"text": ...}} object per document), {@code project_id}, then a
 * {@code parameters} object. {@code parameters} holds the optional
 * {@code truncate_input_tokens} and an always-present {@code return_options} object
 * with the optional {@code inputs} and {@code top_n} entries.
 *
 * @param query        the query documents are ranked against
 * @param inputs       the documents to rank
 * @param taskSettings source of the optional request parameters
 * @param modelId      value written to {@code model_id}
 * @param projectId    value written to {@code project_id}
 */
public record IbmWatsonxRerankRequestEntity(
    String query,
    List<String> inputs,
    IbmWatsonxRerankTaskSettings taskSettings,
    String modelId,
    String projectId
) implements ToXContentObject {

    private static final String INPUTS_FIELD = "inputs";
    private static final String QUERY_FIELD = "query";
    private static final String MODEL_ID_FIELD = "model_id";
    private static final String PROJECT_ID_FIELD = "project_id";
    private static final String TEXT_FIELD = "text";
    private static final String PARAMETERS_FIELD = "parameters";
    private static final String TRUNCATE_INPUT_TOKENS_FIELD = "truncate_input_tokens";
    private static final String RETURN_OPTIONS_FIELD = "return_options";
    private static final String RETURN_INPUTS_FIELD = "inputs";
    private static final String TOP_N_FIELD = "top_n";

    /** Rejects null for every component. */
    public IbmWatsonxRerankRequestEntity {
        Objects.requireNonNull(query);
        Objects.requireNonNull(inputs);
        Objects.requireNonNull(modelId);
        Objects.requireNonNull(projectId);
        Objects.requireNonNull(taskSettings);
    }

    @Override
    public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
        builder.startObject();
        builder.field(MODEL_ID_FIELD, modelId).field(QUERY_FIELD, query);

        builder.startArray(INPUTS_FIELD);
        for (String document : inputs) {
            builder.startObject().field(TEXT_FIELD, document).endObject();
        }
        builder.endArray();

        builder.field(PROJECT_ID_FIELD, projectId);
        writeParameters(builder);

        return builder.endObject();
    }

    /** Writes the {@code parameters} object, skipping task settings that are unset. */
    private void writeParameters(XContentBuilder builder) throws IOException {
        builder.startObject(PARAMETERS_FIELD);

        var truncateInputTokens = taskSettings.getTruncateInputTokens();
        if (truncateInputTokens != null) {
            builder.field(TRUNCATE_INPUT_TOKENS_FIELD, truncateInputTokens);
        }
        writeReturnOptions(builder);

        builder.endObject();
    }

    /** Writes the {@code return_options} object; it is emitted even when both entries are unset. */
    private void writeReturnOptions(XContentBuilder builder) throws IOException {
        builder.startObject(RETURN_OPTIONS_FIELD);

        var returnDocuments = taskSettings.getDoesReturnDocuments();
        if (returnDocuments != null) {
            builder.field(RETURN_INPUTS_FIELD, returnDocuments);
        }

        var topN = taskSettings.getTopNDocumentsOnly();
        if (topN != null) {
            builder.field(TOP_N_FIELD, topN);
        }

        builder.endObject();
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public class IbmWatsonxUtils {
public static final String V1 = "v1";
public static final String TEXT = "text";
public static final String EMBEDDINGS = "embeddings";
public static final String RERANKS = "reranks";

private IbmWatsonxUtils() {}

Expand Down
Loading

0 comments on commit d2b5f05

Please sign in to comment.