forked from elastic/elasticsearch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Integrate watsonx for re-ranking task (elastic#117176)
* Integrate watsonx reranking to inference api * Add api_version to the watsonx api call * Fix the return_doc option * Add top_n parameter to task_settings * Add truncate_input_tokens parameter to task_settings * Add test for IbmWatsonxRankedResponseEntity * Add test for IbmWatsonxRankedRequestEntity * Add test for IbmWatsonxRankedRequest * [CI] Auto commit changes from spotless * Add changelog * Fix transport version * Add test for IbmWatsonxService * Remove canHandleStreamingResponses * Add requireNonNull for modelId and projectId * Remove maxInputToken method * Convert all optionals to required * [CI] Auto commit changes from spotless * Set minimal_supported version to be ML_INFERENCE_IBM_WATSONX_RERANK_ADDED * Remove extraction of unused fields from IbmWatsonxRerankServiceSettings * Add space * Add space --------- Co-authored-by: elasticsearchmachine <infra-root+elasticsearchmachine@elastic.co> (cherry picked from commit 09b1c6d)
- Loading branch information
1 parent
2b7d91a
commit d2b5f05
Showing
21 changed files
with
1,370 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
pr: 117176 | ||
summary: Integrate IBM watsonx to Inference API for re-ranking task | ||
area: Experiences | ||
type: enhancement | ||
issues: [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
72 changes: 72 additions & 0 deletions
72
...rg/elasticsearch/xpack/inference/external/http/sender/IbmWatsonxRerankRequestManager.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.http.sender; | ||
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.elasticsearch.action.ActionListener; | ||
import org.elasticsearch.inference.InferenceServiceResults; | ||
import org.elasticsearch.threadpool.ThreadPool; | ||
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; | ||
import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; | ||
import org.elasticsearch.xpack.inference.external.ibmwatsonx.IbmWatsonxResponseHandler; | ||
import org.elasticsearch.xpack.inference.external.request.ibmwatsonx.IbmWatsonxRerankRequest; | ||
import org.elasticsearch.xpack.inference.external.response.ibmwatsonx.IbmWatsonxRankedResponseEntity; | ||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; | ||
|
||
import java.util.List; | ||
import java.util.Objects; | ||
import java.util.function.Supplier; | ||
|
||
public class IbmWatsonxRerankRequestManager extends IbmWatsonxRequestManager { | ||
private static final Logger logger = LogManager.getLogger(IbmWatsonxRerankRequestManager.class); | ||
private static final ResponseHandler HANDLER = createIbmWatsonxResponseHandler(); | ||
|
||
private static ResponseHandler createIbmWatsonxResponseHandler() { | ||
return new IbmWatsonxResponseHandler( | ||
"ibm watsonx rerank", | ||
(request, response) -> IbmWatsonxRankedResponseEntity.fromResponse(response) | ||
); | ||
} | ||
|
||
public static IbmWatsonxRerankRequestManager of(IbmWatsonxRerankModel model, ThreadPool threadPool) { | ||
return new IbmWatsonxRerankRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); | ||
} | ||
|
||
private final IbmWatsonxRerankModel model; | ||
|
||
public IbmWatsonxRerankRequestManager(IbmWatsonxRerankModel model, ThreadPool threadPool) { | ||
super(threadPool, model); | ||
this.model = model; | ||
} | ||
|
||
@Override | ||
public void execute( | ||
InferenceInputs inferenceInputs, | ||
RequestSender requestSender, | ||
Supplier<Boolean> hasRequestCompletedFunction, | ||
ActionListener<InferenceServiceResults> listener | ||
) { | ||
var rerankInput = QueryAndDocsInputs.of(inferenceInputs); | ||
|
||
execute( | ||
new ExecutableInferenceRequest( | ||
requestSender, | ||
logger, | ||
getRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model), | ||
HANDLER, | ||
hasRequestCompletedFunction, | ||
listener | ||
) | ||
); | ||
} | ||
|
||
protected IbmWatsonxRerankRequest getRerankRequest(String query, List<String> chunks, IbmWatsonxRerankModel model) { | ||
return new IbmWatsonxRerankRequest(query, chunks, model); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
110 changes: 110 additions & 0 deletions
110
...rg/elasticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.request.ibmwatsonx; | ||
|
||
import org.apache.http.HttpHeaders; | ||
import org.apache.http.client.methods.HttpPost; | ||
import org.apache.http.entity.ByteArrayEntity; | ||
import org.elasticsearch.common.Strings; | ||
import org.elasticsearch.xcontent.XContentType; | ||
import org.elasticsearch.xpack.inference.external.request.HttpRequest; | ||
import org.elasticsearch.xpack.inference.external.request.Request; | ||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankModel; | ||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings; | ||
|
||
import java.net.URI; | ||
import java.net.URISyntaxException; | ||
import java.nio.charset.StandardCharsets; | ||
import java.util.List; | ||
import java.util.Objects; | ||
|
||
public class IbmWatsonxRerankRequest implements IbmWatsonxRequest { | ||
|
||
private final String query; | ||
private final List<String> input; | ||
private final IbmWatsonxRerankTaskSettings taskSettings; | ||
private final IbmWatsonxRerankModel model; | ||
|
||
public IbmWatsonxRerankRequest(String query, List<String> input, IbmWatsonxRerankModel model) { | ||
Objects.requireNonNull(model); | ||
|
||
this.input = Objects.requireNonNull(input); | ||
this.query = Objects.requireNonNull(query); | ||
taskSettings = model.getTaskSettings(); | ||
this.model = model; | ||
} | ||
|
||
@Override | ||
public HttpRequest createHttpRequest() { | ||
URI uri; | ||
|
||
try { | ||
uri = new URI(model.uri().toString()); | ||
} catch (URISyntaxException ex) { | ||
throw new IllegalArgumentException("cannot parse URI patter"); | ||
} | ||
|
||
HttpPost httpPost = new HttpPost(uri); | ||
|
||
ByteArrayEntity byteEntity = new ByteArrayEntity( | ||
Strings.toString( | ||
new IbmWatsonxRerankRequestEntity( | ||
query, | ||
input, | ||
taskSettings, | ||
model.getServiceSettings().modelId(), | ||
model.getServiceSettings().projectId() | ||
) | ||
).getBytes(StandardCharsets.UTF_8) | ||
); | ||
|
||
httpPost.setEntity(byteEntity); | ||
httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); | ||
|
||
decorateWithAuth(httpPost); | ||
|
||
return new HttpRequest(httpPost, getInferenceEntityId()); | ||
} | ||
|
||
public void decorateWithAuth(HttpPost httpPost) { | ||
IbmWatsonxRequest.decorateWithBearerToken(httpPost, model.getSecretSettings(), model.getInferenceEntityId()); | ||
} | ||
|
||
@Override | ||
public String getInferenceEntityId() { | ||
return model.getInferenceEntityId(); | ||
} | ||
|
||
@Override | ||
public URI getURI() { | ||
return model.uri(); | ||
} | ||
|
||
@Override | ||
public Request truncate() { | ||
return this; | ||
} | ||
|
||
public String getQuery() { | ||
return query; | ||
} | ||
|
||
public List<String> getInput() { | ||
return input; | ||
} | ||
|
||
public IbmWatsonxRerankModel getModel() { | ||
return model; | ||
} | ||
|
||
@Override | ||
public boolean[] getTruncationInfo() { | ||
return null; | ||
} | ||
|
||
} |
77 changes: 77 additions & 0 deletions
77
...sticsearch/xpack/inference/external/request/ibmwatsonx/IbmWatsonxRerankRequestEntity.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
/* | ||
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one | ||
* or more contributor license agreements. Licensed under the Elastic License | ||
* 2.0; you may not use this file except in compliance with the Elastic License | ||
* 2.0. | ||
*/ | ||
|
||
package org.elasticsearch.xpack.inference.external.request.ibmwatsonx; | ||
|
||
import org.elasticsearch.xcontent.ToXContentObject; | ||
import org.elasticsearch.xcontent.XContentBuilder; | ||
import org.elasticsearch.xpack.inference.services.ibmwatsonx.rerank.IbmWatsonxRerankTaskSettings; | ||
|
||
import java.io.IOException; | ||
import java.util.List; | ||
import java.util.Objects; | ||
|
||
public record IbmWatsonxRerankRequestEntity( | ||
String query, | ||
List<String> inputs, | ||
IbmWatsonxRerankTaskSettings taskSettings, | ||
String modelId, | ||
String projectId | ||
) implements ToXContentObject { | ||
|
||
private static final String INPUTS_FIELD = "inputs"; | ||
private static final String QUERY_FIELD = "query"; | ||
private static final String MODEL_ID_FIELD = "model_id"; | ||
private static final String PROJECT_ID_FIELD = "project_id"; | ||
|
||
public IbmWatsonxRerankRequestEntity { | ||
Objects.requireNonNull(query); | ||
Objects.requireNonNull(inputs); | ||
Objects.requireNonNull(modelId); | ||
Objects.requireNonNull(projectId); | ||
Objects.requireNonNull(taskSettings); | ||
} | ||
|
||
@Override | ||
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { | ||
builder.startObject(); | ||
|
||
builder.field(MODEL_ID_FIELD, modelId); | ||
builder.field(QUERY_FIELD, query); | ||
builder.startArray(INPUTS_FIELD); | ||
for (String input : inputs) { | ||
builder.startObject(); | ||
builder.field("text", input); | ||
builder.endObject(); | ||
} | ||
builder.endArray(); | ||
builder.field(PROJECT_ID_FIELD, projectId); | ||
|
||
builder.startObject("parameters"); | ||
{ | ||
if (taskSettings.getTruncateInputTokens() != null) { | ||
builder.field("truncate_input_tokens", taskSettings.getTruncateInputTokens()); | ||
} | ||
|
||
builder.startObject("return_options"); | ||
{ | ||
if (taskSettings.getDoesReturnDocuments() != null) { | ||
builder.field("inputs", taskSettings.getDoesReturnDocuments()); | ||
} | ||
if (taskSettings.getTopNDocumentsOnly() != null) { | ||
builder.field("top_n", taskSettings.getTopNDocumentsOnly()); | ||
} | ||
} | ||
builder.endObject(); | ||
} | ||
builder.endObject(); | ||
|
||
builder.endObject(); | ||
|
||
return builder; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.