diff --git a/docs/changelog/126805.yaml b/docs/changelog/126805.yaml new file mode 100644 index 0000000000000..ee9a4be7e4fd5 --- /dev/null +++ b/docs/changelog/126805.yaml @@ -0,0 +1,6 @@ +pr: 126805 +summary: Fixing bug with `TransportPutModelAction` listener and adding timeout to + request +area: Machine Learning +type: bug +issues: [] diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 51482a99dc8b1..1c400875554eb 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -160,6 +160,7 @@ static TransportVersion def(int id) { public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16); public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17); public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19); + public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_20); public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00); public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01); public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02); @@ -223,6 +224,7 @@ static TransportVersion def(int id) { public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING = def(9_050_0_00); public static final TransportVersion ESQL_QUERY_PLANNING_DURATION = def(9_051_0_00); public static final TransportVersion ESQL_DOCUMENTS_FOUND_AND_VALUES_LOADED = def(9_052_0_00); + public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_053_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java index a7f65c60a06c4..cded88c36388c 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelAction.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.inference.action; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; @@ -15,6 +16,7 @@ import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.ToXContentObject; @@ -41,13 +43,15 @@ public static class Request extends AcknowledgedRequest { private final String inferenceEntityId; private final BytesReference content; private final XContentType contentType; + private final TimeValue timeout; - public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) { + public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType, TimeValue timeout) { super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); this.taskType = taskType; this.inferenceEntityId = inferenceEntityId; this.content = content; this.contentType = contentType; + this.timeout = timeout; } public Request(StreamInput in) throws IOException { @@ -56,6 +60,13 @@ public Request(StreamInput in) throws IOException { this.taskType = TaskType.fromStream(in); this.content = in.readBytesReference(); this.contentType = in.readEnum(XContentType.class); + + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT) + || in.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) { + this.timeout = in.readTimeValue(); + } else { + this.timeout = InferenceAction.Request.DEFAULT_TIMEOUT; + } } public TaskType getTaskType() { @@ -74,6 +85,10 @@ public XContentType getContentType() { return contentType; } + public TimeValue getTimeout() { + return timeout; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); @@ -81,6 +96,11 @@ public void writeTo(StreamOutput out) throws IOException { taskType.writeTo(out); out.writeBytesReference(content); XContentHelper.writeTo(out, contentType); + + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT) + || out.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) { + out.writeTimeValue(timeout); + } } @Override @@ -105,12 +125,13 @@ public boolean equals(Object o) { return taskType == request.taskType && Objects.equals(inferenceEntityId, request.inferenceEntityId) && Objects.equals(content, request.content) - && contentType == request.contentType; + && contentType == request.contentType + && Objects.equals(timeout, request.timeout); } @Override public int hashCode() { - return Objects.hash(taskType, inferenceEntityId, content, contentType); + return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java index e0b04c6fe8769..f9f67167a12b1 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/PutInferenceModelActionTests.java @@ -34,13 +34,25 @@ public void setup() throws Exception { public void testValidate() { // valid model ID - var request = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID + "_-0", BYTES, X_CONTENT_TYPE); + var request = new PutInferenceModelAction.Request( + TASK_TYPE, + MODEL_ID + "_-0", + BYTES, + X_CONTENT_TYPE, + InferenceAction.Request.DEFAULT_TIMEOUT + ); ActionRequestValidationException validationException = request.validate(); assertNull(validationException); // invalid model IDs - var invalidRequest = new PutInferenceModelAction.Request(TASK_TYPE, "", BYTES, X_CONTENT_TYPE); + var invalidRequest = new PutInferenceModelAction.Request( + TASK_TYPE, + "", + BYTES, + X_CONTENT_TYPE, + InferenceAction.Request.DEFAULT_TIMEOUT + ); validationException = invalidRequest.validate(); assertNotNull(validationException); @@ -48,12 +60,19 @@ public void testValidate() { TASK_TYPE, randomAlphaOfLengthBetween(1, 10) + randomFrom(MlStringsTests.SOME_INVALID_CHARS), BYTES, - X_CONTENT_TYPE + X_CONTENT_TYPE, + InferenceAction.Request.DEFAULT_TIMEOUT ); validationException = invalidRequest2.validate(); assertNotNull(validationException); - var invalidRequest3 = new PutInferenceModelAction.Request(TASK_TYPE, null, BYTES, X_CONTENT_TYPE); + var invalidRequest3 = new PutInferenceModelAction.Request( + TASK_TYPE, + null, + BYTES, + X_CONTENT_TYPE, + InferenceAction.Request.DEFAULT_TIMEOUT + ); validationException = invalidRequest3.validate(); assertNotNull(validationException); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java index eeea8a28df486..bc9d87f43ada0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportPutInferenceModelAction.java @@ -177,7 +177,7 @@ protected void masterOperation( return; } - parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener); + parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener); } private void parseAndStoreModel( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java index 655e11996d522..838e6512d805f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestPutInferenceModelAction.java @@ -20,6 +20,7 @@ import java.util.List; import static org.elasticsearch.rest.RestRequest.Method.PUT; +import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH; import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH; @@ -49,8 +50,15 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient taskType = TaskType.ANY; // task type must be defined in the body } + var inferTimeout = parseTimeout(restRequest); var content = restRequest.requiredContent(); - var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType()); + var request = new PutInferenceModelAction.Request( + taskType, + inferenceEntityId, + content, + restRequest.getXContentType(), + inferTimeout + ); return channel -> client.execute( PutInferenceModelAction.INSTANCE, request, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java index f61398fcacacf..e514867780669 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/PutInferenceModelRequestTests.java @@ -7,13 +7,16 @@ package org.elasticsearch.xpack.inference.action; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.inference.TaskType; -import org.elasticsearch.test.AbstractWireSerializingTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; -public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase { +public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase { @Override protected Writeable.Reader instanceReader() { return PutInferenceModelAction.Request::new; @@ -25,38 +28,29 @@ protected PutInferenceModelAction.Request createTestInstance() { randomFrom(TaskType.values()), randomAlphaOfLength(6), randomBytesReference(50), - randomFrom(XContentType.values()) + randomFrom(XContentType.values()), + randomTimeValue() ); } @Override protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) { - return switch (randomIntBetween(0, 3)) { - case 0 -> new PutInferenceModelAction.Request( - TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length], - instance.getInferenceEntityId(), - instance.getContent(), - instance.getContentType() - ); - case 1 -> new PutInferenceModelAction.Request( - instance.getTaskType(), - instance.getInferenceEntityId() + "foo", - instance.getContent(), - instance.getContentType() - ); - case 2 -> new PutInferenceModelAction.Request( - instance.getTaskType(), - instance.getInferenceEntityId(), - randomBytesReference(instance.getContent().length() + 1), - instance.getContentType() - ); - case 3 -> new PutInferenceModelAction.Request( + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) { + if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT) + || version.isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) { + return instance; + } else { + return new PutInferenceModelAction.Request( instance.getTaskType(), instance.getInferenceEntityId(), instance.getContent(), - XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length] + instance.getContentType(), + InferenceAction.Request.DEFAULT_TIMEOUT ); - default -> throw new IllegalStateException(); - }; + } } }