Skip to content

Commit 6026da5

Browse files
authored
server : clean-up completed tasks from waiting list (ggml-org#9531)
ggml-ci
1 parent eca0fab commit 6026da5

File tree

1 file changed

+34
-6
lines changed

1 file changed

+34
-6
lines changed

examples/server/server.cpp

+34-6
Original file line numberDiff line numberDiff line change
@@ -531,26 +531,38 @@ struct server_response {
531531

532532
// add the id_task to the list of tasks waiting for response
533533
void add_waiting_task_id(int id_task) {
534-
SRV_DBG("waiting for task id = %d\n", id_task);
534+
SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size());
535535

536536
std::unique_lock<std::mutex> lock(mutex_results);
537537
waiting_task_ids.insert(id_task);
538538
}
539539

540540
void add_waiting_tasks(const std::vector<server_task> & tasks) {
541-
for (const auto & t : tasks) {
542-
add_waiting_task_id(t.id);
541+
std::unique_lock<std::mutex> lock(mutex_results);
542+
543+
for (const auto & task : tasks) {
544+
SRV_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size());
545+
waiting_task_ids.insert(task.id);
543546
}
544547
}
545548

546549
// when the request is finished, we can remove task associated with it
547550
void remove_waiting_task_id(int id_task) {
548-
SRV_DBG("task id = %d is done\n", id_task);
551+
SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
549552

550553
std::unique_lock<std::mutex> lock(mutex_results);
551554
waiting_task_ids.erase(id_task);
552555
}
553556

557+
void remove_waiting_task_ids(const std::unordered_set<int> & id_tasks) {
558+
std::unique_lock<std::mutex> lock(mutex_results);
559+
560+
for (const auto & id_task : id_tasks) {
561+
SRV_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size());
562+
waiting_task_ids.erase(id_task);
563+
}
564+
}
565+
554566
// This function blocks the thread until there is a response for one of the id_tasks
555567
server_task_result recv(const std::unordered_set<int> & id_tasks) {
556568
while (true) {
@@ -2774,6 +2786,8 @@ int main(int argc, char ** argv) {
27742786
}, [&](const json & error_data) {
27752787
res_error(res, error_data);
27762788
});
2789+
2790+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
27772791
} else {
27782792
const auto chunked_content_provider = [task_ids, &ctx_server](size_t, httplib::DataSink & sink) {
27792793
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
@@ -2784,7 +2798,12 @@ int main(int argc, char ** argv) {
27842798
sink.done();
27852799
return false;
27862800
};
2787-
res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
2801+
2802+
auto on_complete = [task_ids, &ctx_server] (bool) {
2803+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
2804+
};
2805+
2806+
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
27882807
}
27892808
};
27902809

@@ -2823,6 +2842,8 @@ int main(int argc, char ** argv) {
28232842
}, [&](const json & error_data) {
28242843
res_error(res, error_data);
28252844
});
2845+
2846+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
28262847
} else {
28272848
const auto chunked_content_provider = [task_ids, &ctx_server, completion_id](size_t, httplib::DataSink & sink) {
28282849
ctx_server.receive_cmpl_results_stream(task_ids, [&](const server_task_result & result) -> bool {
@@ -2844,7 +2865,12 @@ int main(int argc, char ** argv) {
28442865
sink.done();
28452866
return true;
28462867
};
2847-
res.set_chunked_content_provider("text/event-stream", chunked_content_provider);
2868+
2869+
auto on_complete = [task_ids, &ctx_server] (bool) {
2870+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
2871+
};
2872+
2873+
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
28482874
}
28492875
};
28502876

@@ -2953,6 +2979,8 @@ int main(int argc, char ** argv) {
29532979
res_error(res, error_data);
29542980
error = true;
29552981
});
2982+
2983+
ctx_server.queue_results.remove_waiting_task_ids(task_ids);
29562984
}
29572985

29582986
if (error) {

0 commit comments

Comments
 (0)