Skip to content

[ML] Fixing bug with TransportPutModelAction listener and adding timeout to request #126805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/changelog/126805.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
pr: 126805
summary: Fixing bug with `TransportPutModelAction` listener and adding timeout to
request
area: Machine Learning
type: bug
issues: []
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@ static TransportVersion def(int id) {
public static final TransportVersion REMOTE_EXCEPTION_8_19 = def(8_841_0_16);
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS_8_19 = def(8_841_0_17);
public static final TransportVersion BATCHED_QUERY_PHASE_VERSION_BACKPORT_8_X = def(8_841_0_19);
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19 = def(8_841_0_20);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0 = def(9_000_0_00);
public static final TransportVersion REMOVE_SNAPSHOT_FAILURES_90 = def(9_000_0_01);
public static final TransportVersion TRANSPORT_STATS_HANDLING_TIME_REQUIRED_90 = def(9_000_0_02);
Expand Down Expand Up @@ -222,6 +223,7 @@ static TransportVersion def(int id) {
public static final TransportVersion AMAZON_BEDROCK_TASK_SETTINGS = def(9_049_0_00);
public static final TransportVersion ESQL_REPORT_SHARD_PARTITIONING = def(9_050_0_00);
public static final TransportVersion ESQL_QUERY_PLANNING_DURATION = def(9_051_0_00);
// Version gate for the timeout field that PutInferenceModelAction.Request reads and writes on the wire.
public static final TransportVersion INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT = def(9_052_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

package org.elasticsearch.xpack.core.inference.action;

import org.elasticsearch.TransportVersions;
import org.elasticsearch.action.ActionRequestValidationException;
import org.elasticsearch.action.ActionResponse;
import org.elasticsearch.action.ActionType;
Expand All @@ -15,6 +16,7 @@
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.ModelConfigurations;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.xcontent.ToXContentObject;
Expand All @@ -41,13 +43,15 @@ public static class Request extends AcknowledgedRequest<Request> {
private final String inferenceEntityId;
private final BytesReference content;
private final XContentType contentType;
private final TimeValue timeout;

public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType) {
public Request(TaskType taskType, String inferenceEntityId, BytesReference content, XContentType contentType, TimeValue timeout) {
super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT);
this.taskType = taskType;
this.inferenceEntityId = inferenceEntityId;
this.content = content;
this.contentType = contentType;
this.timeout = timeout;
}

public Request(StreamInput in) throws IOException {
Expand All @@ -56,6 +60,13 @@ public Request(StreamInput in) throws IOException {
this.taskType = TaskType.fromStream(in);
this.content = in.readBytesReference();
this.contentType = in.readEnum(XContentType.class);

if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
|| in.getTransportVersion().isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
this.timeout = in.readTimeValue();
} else {
this.timeout = InferenceAction.Request.DEFAULT_TIMEOUT;
}
}

public TaskType getTaskType() {
Expand All @@ -74,13 +85,22 @@ public XContentType getContentType() {
return contentType;
}

/**
 * Returns the timeout carried by this request.
 * <p>
 * This is the value passed to the constructor, or the value read from the stream. When the
 * request is read from a stream whose transport version predates the timeout field, the
 * deserializing constructor sets it to {@code InferenceAction.Request.DEFAULT_TIMEOUT}.
 *
 * @return the timeout stored on this request
 */
public TimeValue getTimeout() {
    return this.timeout;
}

/**
 * Serializes this request to {@code out}.
 * <p>
 * After the superclass fields, the inference entity id, task type, content and content type
 * are written in that order. The timeout is appended only when the stream's transport version
 * passes the same version check used by the deserializing constructor.
 *
 * @param out the stream to write to
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeTo(StreamOutput out) throws IOException {
    super.writeTo(out);
    out.writeString(inferenceEntityId);
    taskType.writeTo(out);
    out.writeBytesReference(content);
    XContentHelper.writeTo(out, contentType);

    // Only write the timeout for stream versions that pass the version gate.
    final var streamVersion = out.getTransportVersion();
    final boolean timeoutOnWire = streamVersion.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
        || streamVersion.isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19);
    if (timeoutOnWire) {
        out.writeTimeValue(timeout);
    }
}

@Override
Expand All @@ -105,12 +125,13 @@ public boolean equals(Object o) {
return taskType == request.taskType
&& Objects.equals(inferenceEntityId, request.inferenceEntityId)
&& Objects.equals(content, request.content)
&& contentType == request.contentType;
&& contentType == request.contentType
&& Objects.equals(timeout, request.timeout);
}

@Override
public int hashCode() {
return Objects.hash(taskType, inferenceEntityId, content, contentType);
return Objects.hash(taskType, inferenceEntityId, content, contentType, timeout);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,26 +34,45 @@ public void setup() throws Exception {

public void testValidate() {
// valid model ID
var request = new PutInferenceModelAction.Request(TASK_TYPE, MODEL_ID + "_-0", BYTES, X_CONTENT_TYPE);
var request = new PutInferenceModelAction.Request(
TASK_TYPE,
MODEL_ID + "_-0",
BYTES,
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
ActionRequestValidationException validationException = request.validate();
assertNull(validationException);

// invalid model IDs

var invalidRequest = new PutInferenceModelAction.Request(TASK_TYPE, "", BYTES, X_CONTENT_TYPE);
var invalidRequest = new PutInferenceModelAction.Request(
TASK_TYPE,
"",
BYTES,
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
validationException = invalidRequest.validate();
assertNotNull(validationException);

var invalidRequest2 = new PutInferenceModelAction.Request(
TASK_TYPE,
randomAlphaOfLengthBetween(1, 10) + randomFrom(MlStringsTests.SOME_INVALID_CHARS),
BYTES,
X_CONTENT_TYPE
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
validationException = invalidRequest2.validate();
assertNotNull(validationException);

var invalidRequest3 = new PutInferenceModelAction.Request(TASK_TYPE, null, BYTES, X_CONTENT_TYPE);
var invalidRequest3 = new PutInferenceModelAction.Request(
TASK_TYPE,
null,
BYTES,
X_CONTENT_TYPE,
InferenceAction.Request.DEFAULT_TIMEOUT
);
validationException = invalidRequest3.validate();
assertNotNull(validationException);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ protected void masterOperation(
return;
}

parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.ackTimeout(), listener);
parseAndStoreModel(service.get(), request.getInferenceEntityId(), resolvedTaskType, requestAsMap, request.getTimeout(), listener);
}

private void parseAndStoreModel(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.List;

import static org.elasticsearch.rest.RestRequest.Method.PUT;
import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout;
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID;
import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID_PATH;
import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_INFERENCE_ID_PATH;
Expand Down Expand Up @@ -49,8 +50,15 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient
taskType = TaskType.ANY; // task type must be defined in the body
}

var inferTimeout = parseTimeout(restRequest);
var content = restRequest.requiredContent();
var request = new PutInferenceModelAction.Request(taskType, inferenceEntityId, content, restRequest.getXContentType());
var request = new PutInferenceModelAction.Request(
taskType,
inferenceEntityId,
content,
restRequest.getXContentType(),
inferTimeout
);
return channel -> client.execute(
PutInferenceModelAction.INSTANCE,
request,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ public void start(Model model, TimeValue timeout, ActionListener<Boolean> finalL
})
.<Boolean>andThen((l2, modelDidPut) -> {
var startRequest = esModel.getStartTrainedModelDeploymentActionRequest(timeout);
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, finalListener);
var responseListener = esModel.getCreateTrainedModelAssignmentActionListener(model, l2);
client.execute(StartTrainedModelDeploymentAction.INSTANCE, startRequest, responseListener);
})
.addListener(finalListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ public void onFailure(Exception e) {
&& statusException.getRootCause() instanceof ResourceAlreadyExistsException) {
// Deployment is already started
listener.onResponse(Boolean.TRUE);
} else {
listener.onFailure(e);
}
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@

package org.elasticsearch.xpack.inference.action;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.TransportVersions;
import org.elasticsearch.common.io.stream.Writeable;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.AbstractWireSerializingTestCase;
import org.elasticsearch.xcontent.XContentType;
import org.elasticsearch.xpack.core.inference.action.InferenceAction;
import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction;
import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase;

public class PutInferenceModelRequestTests extends AbstractWireSerializingTestCase<PutInferenceModelAction.Request> {
public class PutInferenceModelRequestTests extends AbstractBWCWireSerializationTestCase<PutInferenceModelAction.Request> {
@Override
protected Writeable.Reader<PutInferenceModelAction.Request> instanceReader() {
return PutInferenceModelAction.Request::new;
Expand All @@ -25,38 +28,29 @@ protected PutInferenceModelAction.Request createTestInstance() {
randomFrom(TaskType.values()),
randomAlphaOfLength(6),
randomBytesReference(50),
randomFrom(XContentType.values())
randomFrom(XContentType.values()),
randomTimeValue()
);
}

@Override
protected PutInferenceModelAction.Request mutateInstance(PutInferenceModelAction.Request instance) {
return switch (randomIntBetween(0, 3)) {
case 0 -> new PutInferenceModelAction.Request(
TaskType.values()[(instance.getTaskType().ordinal() + 1) % TaskType.values().length],
instance.getInferenceEntityId(),
instance.getContent(),
instance.getContentType()
);
case 1 -> new PutInferenceModelAction.Request(
instance.getTaskType(),
instance.getInferenceEntityId() + "foo",
instance.getContent(),
instance.getContentType()
);
case 2 -> new PutInferenceModelAction.Request(
instance.getTaskType(),
instance.getInferenceEntityId(),
randomBytesReference(instance.getContent().length() + 1),
instance.getContentType()
);
case 3 -> new PutInferenceModelAction.Request(
return randomValueOtherThan(instance, this::createTestInstance);
}

@Override
protected PutInferenceModelAction.Request mutateInstanceForVersion(PutInferenceModelAction.Request instance, TransportVersion version) {
if (version.onOrAfter(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT)
|| version.isPatchFrom(TransportVersions.INFERENCE_ADD_TIMEOUT_PUT_ENDPOINT_8_19)) {
return instance;
} else {
return new PutInferenceModelAction.Request(
instance.getTaskType(),
instance.getInferenceEntityId(),
instance.getContent(),
XContentType.values()[(instance.getContentType().ordinal() + 1) % XContentType.values().length]
instance.getContentType(),
InferenceAction.Request.DEFAULT_TIMEOUT
);
default -> throw new IllegalStateException();
};
}
}
}
Loading