Skip to content

Commit a4111ec

Browse files
[ML] Refactor inference request executor to leverage scheduled execution (#126858) (#126948)
* Using threadpool schedule and fixing tests * Update docs/changelog/126858.yaml * Clean up * change log (cherry picked from commit 7a0f63c) # Conflicts: # x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java
1 parent fbc135b commit a4111ec

File tree

4 files changed

+57
-74
lines changed

4 files changed

+57
-74
lines changed

docs/changelog/126858.yaml

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pr: 126858
2+
summary: Leverage threadpool schedule for inference api to avoid long running thread
3+
area: Machine Learning
4+
type: bug
5+
issues:
6+
- 126853

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java

+41-36
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,6 @@
5757
*/
5858
class RequestExecutorService implements RequestExecutor {
5959

60-
/**
61-
* Provides dependency injection mainly for testing
62-
*/
63-
interface Sleeper {
64-
void sleep(TimeValue sleepTime) throws InterruptedException;
65-
}
66-
67-
// default for tests
68-
static final Sleeper DEFAULT_SLEEPER = sleepTime -> sleepTime.timeUnit().sleep(sleepTime.duration());
6960
// default for tests
7061
static final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> DEFAULT_QUEUE_CREATOR =
7162
new AdjustableCapacityBlockingQueue.QueueCreator<>() {
@@ -118,7 +109,6 @@ interface RateLimiterCreator {
118109
private final Clock clock;
119110
private final AtomicBoolean shutdown = new AtomicBoolean(false);
120111
private final AdjustableCapacityBlockingQueue.QueueCreator<RejectableTask> queueCreator;
121-
private final Sleeper sleeper;
122112
private final RateLimiterCreator rateLimiterCreator;
123113
private final AtomicReference<Scheduler.Cancellable> cancellableCleanupTask = new AtomicReference<>();
124114
private final AtomicBoolean started = new AtomicBoolean(false);
@@ -129,16 +119,7 @@ interface RateLimiterCreator {
129119
RequestExecutorServiceSettings settings,
130120
RequestSender requestSender
131121
) {
132-
this(
133-
threadPool,
134-
DEFAULT_QUEUE_CREATOR,
135-
startupLatch,
136-
settings,
137-
requestSender,
138-
Clock.systemUTC(),
139-
DEFAULT_SLEEPER,
140-
DEFAULT_RATE_LIMIT_CREATOR
141-
);
122+
this(threadPool, DEFAULT_QUEUE_CREATOR, startupLatch, settings, requestSender, Clock.systemUTC(), DEFAULT_RATE_LIMIT_CREATOR);
142123
}
143124

144125
RequestExecutorService(
@@ -148,7 +129,6 @@ interface RateLimiterCreator {
148129
RequestExecutorServiceSettings settings,
149130
RequestSender requestSender,
150131
Clock clock,
151-
Sleeper sleeper,
152132
RateLimiterCreator rateLimiterCreator
153133
) {
154134
this.threadPool = Objects.requireNonNull(threadPool);
@@ -157,7 +137,6 @@ interface RateLimiterCreator {
157137
this.requestSender = Objects.requireNonNull(requestSender);
158138
this.settings = Objects.requireNonNull(settings);
159139
this.clock = Objects.requireNonNull(clock);
160-
this.sleeper = Objects.requireNonNull(sleeper);
161140
this.rateLimiterCreator = Objects.requireNonNull(rateLimiterCreator);
162141
}
163142

@@ -213,15 +192,10 @@ public void start() {
213192
startCleanupTask();
214193
signalStartInitiated();
215194

216-
while (isShutdown() == false) {
217-
handleTasks();
218-
}
219-
} catch (InterruptedException e) {
220-
Thread.currentThread().interrupt();
221-
} finally {
222-
shutdown();
223-
notifyRequestsOfShutdown();
224-
terminationLatch.countDown();
195+
handleTasks();
196+
} catch (Exception e) {
197+
logger.warn("Failed to start request executor", e);
198+
cleanup();
225199
}
226200
}
227201

@@ -256,13 +230,44 @@ void removeStaleGroupings() {
256230
}
257231
}
258232

259-
private void handleTasks() throws InterruptedException {
260-
var timeToWait = settings.getTaskPollFrequency();
261-
for (var endpoint : rateLimitGroupings.values()) {
262-
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
233+
private void scheduleNextHandleTasks(TimeValue timeToWait) {
234+
if (shutdown.get()) {
235+
logger.debug("Shutdown requested while scheduling next handle task call, cleaning up");
236+
cleanup();
237+
return;
238+
}
239+
240+
threadPool.schedule(this::handleTasks, timeToWait, threadPool.executor(UTILITY_THREAD_POOL_NAME));
241+
}
242+
243+
private void cleanup() {
244+
try {
245+
shutdown();
246+
notifyRequestsOfShutdown();
247+
terminationLatch.countDown();
248+
} catch (Exception e) {
249+
logger.warn("Encountered an error while cleaning up", e);
263250
}
251+
}
264252

265-
sleeper.sleep(timeToWait);
253+
private void handleTasks() {
254+
try {
255+
if (shutdown.get()) {
256+
logger.debug("Shutdown requested while handling tasks, cleaning up");
257+
cleanup();
258+
return;
259+
}
260+
261+
var timeToWait = settings.getTaskPollFrequency();
262+
for (var endpoint : rateLimitGroupings.values()) {
263+
timeToWait = TimeValue.min(endpoint.executeEnqueuedTask(), timeToWait);
264+
}
265+
266+
scheduleNextHandleTasks(timeToWait);
267+
} catch (Exception e) {
268+
logger.warn("Encountered an error while handling tasks", e);
269+
cleanup();
270+
}
266271
}
267272

268273
private void notifyRequestsOfShutdown() {

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl;
5151
import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER;
5252
import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat;
53+
import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings;
5354
import static org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceService.ELASTIC_INFERENCE_SERVICE_IDENTIFIER;
5455
import static org.hamcrest.Matchers.equalTo;
5556
import static org.hamcrest.Matchers.hasSize;
@@ -88,7 +89,7 @@ public void shutdown() throws IOException, InterruptedException {
8889
}
8990

9091
public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
91-
var senderFactory = createSenderFactory(clientManager, threadRef);
92+
var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
9293

9394
try (var sender = createSender(senderFactory)) {
9495
sender.start();

x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java

+8-37
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
import static org.mockito.ArgumentMatchers.any;
5151
import static org.mockito.ArgumentMatchers.anyInt;
5252
import static org.mockito.Mockito.doAnswer;
53-
import static org.mockito.Mockito.doThrow;
5453
import static org.mockito.Mockito.mock;
5554
import static org.mockito.Mockito.times;
5655
import static org.mockito.Mockito.verify;
@@ -195,7 +194,7 @@ public void testExecute_Throws_WhenQueueIsFull() {
195194
assertFalse(thrownException.isExecutorShutdown());
196195
}
197196

198-
public void testTaskThrowsError_CallsOnFailure() {
197+
public void testTaskThrowsError_CallsOnFailure() throws InterruptedException {
199198
var requestSender = mock(RetryingHttpSender.class);
200199

201200
var service = createRequestExecutorService(null, requestSender);
@@ -218,6 +217,8 @@ public void testTaskThrowsError_CallsOnFailure() {
218217
var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT));
219218
assertThat(thrownException.getMessage(), is(format("Failed to send request from inference entity id [%s]", "id")));
220219
assertThat(thrownException.getCause(), instanceOf(IllegalArgumentException.class));
220+
service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
221+
221222
assertTrue(service.isTerminated());
222223
}
223224

@@ -340,7 +341,6 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I
340341
createRequestExecutorServiceSettingsEmpty(),
341342
requestSender,
342343
Clock.systemUTC(),
343-
RequestExecutorService.DEFAULT_SLEEPER,
344344
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
345345
);
346346

@@ -354,36 +354,7 @@ public void testQueuePoll_DoesNotCauseServiceToTerminate_WhenItThrows() throws I
354354
});
355355
service.start();
356356

357-
assertTrue(service.isTerminated());
358-
}
359-
360-
public void testSleep_ThrowingInterruptedException_TerminatesService() throws Exception {
361-
@SuppressWarnings("unchecked")
362-
BlockingQueue<RejectableTask> queue = mock(LinkedBlockingQueue.class);
363-
var sleeper = mock(RequestExecutorService.Sleeper.class);
364-
doThrow(new InterruptedException("failed")).when(sleeper).sleep(any());
365-
366-
var service = new RequestExecutorService(
367-
threadPool,
368-
mockQueueCreator(queue),
369-
null,
370-
createRequestExecutorServiceSettingsEmpty(),
371-
mock(RetryingHttpSender.class),
372-
Clock.systemUTC(),
373-
sleeper,
374-
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
375-
);
376-
377-
Future<?> executorTermination = threadPool.generic().submit(() -> {
378-
try {
379-
service.start();
380-
} catch (Exception e) {
381-
fail(Strings.format("Failed to shutdown executor: %s", e));
382-
}
383-
});
384-
385-
executorTermination.get(TIMEOUT.millis(), TimeUnit.MILLISECONDS);
386-
357+
service.awaitTermination(TIMEOUT.getSeconds(), TimeUnit.SECONDS);
387358
assertTrue(service.isTerminated());
388359
}
389360

@@ -550,7 +521,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens() {
550521
settings,
551522
requestSender,
552523
Clock.systemUTC(),
553-
RequestExecutorService.DEFAULT_SLEEPER,
554524
rateLimiterCreator
555525
);
556526
var requestManager = RequestManagerTests.createMock(requestSender);
@@ -583,7 +553,6 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And
583553
settings,
584554
requestSender,
585555
Clock.systemUTC(),
586-
RequestExecutorService.DEFAULT_SLEEPER,
587556
rateLimiterCreator
588557
);
589558
var requestManager = RequestManagerTests.createMock(requestSender);
@@ -595,11 +564,15 @@ public void testDoesNotExecuteTask_WhenCannotReserveTokens_AndThenCanReserve_And
595564

596565
doAnswer(invocation -> {
597566
service.shutdown();
567+
ActionListener<InferenceServiceResults> passedListener = invocation.getArgument(4);
568+
passedListener.onResponse(null);
569+
598570
return Void.TYPE;
599571
}).when(requestSender).send(any(), any(), any(), any(), any());
600572

601573
service.start();
602574

575+
listener.actionGet(TIMEOUT);
603576
verify(requestSender, times(1)).send(any(), any(), any(), any(), any());
604577
}
605578

@@ -617,7 +590,6 @@ public void testRemovesRateLimitGroup_AfterStaleDuration() {
617590
settings,
618591
requestSender,
619592
clock,
620-
RequestExecutorService.DEFAULT_SLEEPER,
621593
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
622594
);
623595
var requestManager = RequestManagerTests.createMock(requestSender, "id1");
@@ -651,7 +623,6 @@ public void testStartsCleanupThread() {
651623
settings,
652624
requestSender,
653625
Clock.systemUTC(),
654-
RequestExecutorService.DEFAULT_SLEEPER,
655626
RequestExecutorService.DEFAULT_RATE_LIMIT_CREATOR
656627
);
657628

0 commit comments

Comments
 (0)