Skip to content

Commit 77e56ce

Browse files
[ML] Remove error parsing functionality for custom service (#128778) (#129638)
* Remove error parsing class * Adding test for lack of error parsing logic * Adding transport version check * Wrapping string in try/catch and adding test (cherry picked from commit f096773) # Conflicts: # server/src/main/java/org/elasticsearch/TransportVersions.java
1 parent 83dd930 commit 77e56ce

File tree

13 files changed

+161
-435
lines changed

13 files changed

+161
-435
lines changed

server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ static TransportVersion def(int id) {
242242
public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49);
243243
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50);
244244
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_8_19 = def(8_841_0_51);
245+
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52);
245246

246247
/*
247248
* STOP! READ THIS FIRST! No, really,

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManager.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ public static RateLimitGrouping of(CustomModel model) {
4141
}
4242
}
4343

44-
private static ResponseHandler createCustomHandler(CustomModel model) {
45-
return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse, model.getServiceSettings().getErrorParser());
44+
private static ResponseHandler createCustomHandler() {
45+
return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse);
4646
}
4747

4848
public static CustomRequestManager of(CustomModel model, ThreadPool threadPool) {
@@ -55,7 +55,7 @@ public static CustomRequestManager of(CustomModel model, ThreadPool threadPool)
5555
private CustomRequestManager(CustomModel model, ThreadPool threadPool) {
5656
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
5757
this.model = model;
58-
this.handler = createCustomHandler(model);
58+
this.handler = createCustomHandler();
5959
}
6060

6161
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomResponseHandler.java

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,34 @@
88
package org.elasticsearch.xpack.inference.services.custom;
99

1010
import org.elasticsearch.ElasticsearchStatusException;
11+
import org.elasticsearch.common.Strings;
1112
import org.elasticsearch.inference.InferenceServiceResults;
1213
import org.elasticsearch.rest.RestStatus;
1314
import org.elasticsearch.xpack.inference.external.http.HttpResult;
1415
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
16+
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
1517
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
1618
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
1719
import org.elasticsearch.xpack.inference.external.request.Request;
18-
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
20+
21+
import java.nio.charset.StandardCharsets;
22+
import java.util.function.Function;
1923

2024
/**
2125
* Defines how to handle various response types returned from the custom integration.
2226
*/
2327
public class CustomResponseHandler extends BaseResponseHandler {
24-
public CustomResponseHandler(String requestType, ResponseParser parseFunction, ErrorResponseParser errorParser) {
25-
super(requestType, parseFunction, errorParser);
28+
// default for testing
29+
static final Function<HttpResult, ErrorResponse> ERROR_PARSER = (httpResult) -> {
30+
try {
31+
return new ErrorResponse(new String(httpResult.body(), StandardCharsets.UTF_8));
32+
} catch (Exception e) {
33+
return new ErrorResponse(Strings.format("Failed to parse error response body: %s", e.getMessage()));
34+
}
35+
};
36+
37+
public CustomResponseHandler(String requestType, ResponseParser parseFunction) {
38+
super(requestType, parseFunction, ERROR_PARSER);
2639
}
2740

2841
@Override

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,7 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom
249249
serviceSettings.getQueryParameters(),
250250
serviceSettings.getRequestContentString(),
251251
serviceSettings.getResponseJsonParser(),
252-
serviceSettings.rateLimitSettings(),
253-
serviceSettings.getErrorParser()
252+
serviceSettings.rateLimitSettings()
254253
);
255254
}
256255

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
2727
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
2828
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
29-
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
3029
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
3130
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
3231
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
@@ -59,7 +58,6 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
5958
public static final String REQUEST = "request";
6059
public static final String RESPONSE = "response";
6160
public static final String JSON_PARSER = "json_parser";
62-
public static final String ERROR_PARSER = "error_parser";
6361

6462
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
6563
private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE);
@@ -100,15 +98,6 @@ public static CustomServiceSettings fromMap(
10098

10199
var responseJsonParser = extractResponseParser(taskType, jsonParserMap, validationException);
102100

103-
Map<String, Object> errorParserMap = extractRequiredMap(
104-
Objects.requireNonNullElse(responseParserMap, new HashMap<>()),
105-
ERROR_PARSER,
106-
RESPONSE_SCOPE,
107-
validationException
108-
);
109-
110-
var errorParser = ErrorResponseParser.fromMap(errorParserMap, RESPONSE_SCOPE, inferenceId, validationException);
111-
112101
RateLimitSettings rateLimitSettings = RateLimitSettings.of(
113102
map,
114103
DEFAULT_RATE_LIMIT_SETTINGS,
@@ -117,13 +106,12 @@ public static CustomServiceSettings fromMap(
117106
context
118107
);
119108

120-
if (responseParserMap == null || jsonParserMap == null || errorParserMap == null) {
109+
if (responseParserMap == null || jsonParserMap == null) {
121110
throw validationException;
122111
}
123112

124113
throwIfNotEmptyMap(jsonParserMap, JSON_PARSER, NAME);
125114
throwIfNotEmptyMap(responseParserMap, RESPONSE, NAME);
126-
throwIfNotEmptyMap(errorParserMap, ERROR_PARSER, NAME);
127115

128116
if (validationException.validationErrors().isEmpty() == false) {
129117
throw validationException;
@@ -136,8 +124,7 @@ public static CustomServiceSettings fromMap(
136124
queryParams,
137125
requestContentString,
138126
responseJsonParser,
139-
rateLimitSettings,
140-
errorParser
127+
rateLimitSettings
141128
);
142129
}
143130

@@ -209,7 +196,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
209196
private final String requestContentString;
210197
private final CustomResponseParser responseJsonParser;
211198
private final RateLimitSettings rateLimitSettings;
212-
private final ErrorResponseParser errorParser;
213199

214200
public CustomServiceSettings(
215201
TextEmbeddingSettings textEmbeddingSettings,
@@ -218,8 +204,7 @@ public CustomServiceSettings(
218204
@Nullable QueryParameters queryParameters,
219205
String requestContentString,
220206
CustomResponseParser responseJsonParser,
221-
@Nullable RateLimitSettings rateLimitSettings,
222-
ErrorResponseParser errorParser
207+
@Nullable RateLimitSettings rateLimitSettings
223208
) {
224209
this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings);
225210
this.url = Objects.requireNonNull(url);
@@ -228,7 +213,6 @@ public CustomServiceSettings(
228213
this.requestContentString = Objects.requireNonNull(requestContentString);
229214
this.responseJsonParser = Objects.requireNonNull(responseJsonParser);
230215
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
231-
this.errorParser = Objects.requireNonNull(errorParser);
232216
}
233217

234218
public CustomServiceSettings(StreamInput in) throws IOException {
@@ -239,7 +223,11 @@ public CustomServiceSettings(StreamInput in) throws IOException {
239223
requestContentString = in.readString();
240224
responseJsonParser = in.readNamedWriteable(CustomResponseParser.class);
241225
rateLimitSettings = new RateLimitSettings(in);
242-
errorParser = new ErrorResponseParser(in);
226+
if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19)) {
227+
// Read the error parsing fields for backwards compatibility
228+
in.readString();
229+
in.readString();
230+
}
243231
}
244232

245233
@Override
@@ -287,10 +275,6 @@ public CustomResponseParser getResponseJsonParser() {
287275
return responseJsonParser;
288276
}
289277

290-
public ErrorResponseParser getErrorParser() {
291-
return errorParser;
292-
}
293-
294278
@Override
295279
public RateLimitSettings rateLimitSettings() {
296280
return rateLimitSettings;
@@ -331,7 +315,6 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder
331315
builder.startObject(RESPONSE);
332316
{
333317
responseJsonParser.toXContent(builder, params);
334-
errorParser.toXContent(builder, params);
335318
}
336319
builder.endObject();
337320

@@ -359,7 +342,11 @@ public void writeTo(StreamOutput out) throws IOException {
359342
out.writeString(requestContentString);
360343
out.writeNamedWriteable(responseJsonParser);
361344
rateLimitSettings.writeTo(out);
362-
errorParser.writeTo(out);
345+
if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19)) {
346+
// Write empty strings for backwards compatibility for the error parsing fields
347+
out.writeString("");
348+
out.writeString("");
349+
}
363350
}
364351

365352
@Override
@@ -373,8 +360,7 @@ public boolean equals(Object o) {
373360
&& Objects.equals(queryParameters, that.queryParameters)
374361
&& Objects.equals(requestContentString, that.requestContentString)
375362
&& Objects.equals(responseJsonParser, that.responseJsonParser)
376-
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
377-
&& Objects.equals(errorParser, that.errorParser);
363+
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
378364
}
379365

380366
@Override
@@ -386,8 +372,7 @@ public int hashCode() {
386372
queryParameters,
387373
requestContentString,
388374
responseJsonParser,
389-
rateLimitSettings,
390-
errorParser
375+
rateLimitSettings
391376
);
392377
}
393378

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/response/ErrorResponseParser.java

Lines changed: 0 additions & 128 deletions
This file was deleted.

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomModelTests.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import org.elasticsearch.inference.TaskType;
1616
import org.elasticsearch.test.ESTestCase;
1717
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
18-
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
1918
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
2019
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
2120
import org.hamcrest.MatcherAssert;
@@ -120,8 +119,7 @@ public static CustomModel getTestModel(TaskType taskType, CustomResponseParser r
120119
QueryParameters.EMPTY,
121120
requestContentString,
122121
responseParser,
123-
new RateLimitSettings(10_000),
124-
new ErrorResponseParser("$.error.message", inferenceId)
122+
new RateLimitSettings(10_000)
125123
);
126124

127125
CustomTaskSettings taskSettings = new CustomTaskSettings(Map.of(taskSettingsKey, taskSettingsValue));

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomRequestManagerTests.java

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import org.elasticsearch.threadpool.ThreadPool;
1818
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
1919
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
20-
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
2120
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
2221
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
2322
import org.junit.After;
@@ -64,8 +63,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() {
6463
null,
6564
requestContentString,
6665
new RerankResponseParser("$.result.score"),
67-
new RateLimitSettings(10_000),
68-
new ErrorResponseParser("$.error.message", inferenceId)
66+
new RateLimitSettings(10_000)
6967
);
7068

7169
var model = CustomModelTests.createModel(

0 commit comments

Comments
 (0)