server: (refactor) implement generator-based API for task results (#17174)
* server: (refactor) implement generator-based API for task results
* improve
* moving some code
* fix "Response ended prematurely"
* add sink.done before return false
* rm redundant check
* rm unused var
* rename generator --> reader
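In effect, this replaces the callback-based receive_multi_results() / receive_cmpl_results_stream() helpers in server_context with a pull-style server_response_reader that the HTTP handlers drive themselves. A rough usage sketch, distilled from the diff below (the surrounding handler, the prepared tasks vector, the responses array and the send_chunk() helper are illustrative placeholders, not actual server code):

    // non-streaming: post the tasks, then aggregate every result at once
    server_response_reader rd(ctx_server);
    rd.post_tasks(std::move(tasks));
    auto all = rd.wait_for_all(is_connection_closed);                 // returns a batch_response
    if (all.is_terminated) return;                                    // client disconnected, tasks were cancelled
    if (all.error) { res_error(res, all.error->to_json()); return; }  // the first error aborts the batch
    for (auto & r : all.results) responses.push_back(r->to_json());   // one result per task, ordered by index

    // streaming: pull results one at a time until every task has reported is_stop()
    while (rd.has_next()) {
        auto result = rd.next(is_connection_closed);                  // nullptr => connection closed / should_stop fired
        if (result == nullptr || result->is_error()) break;
        send_chunk(result->to_json());                                // placeholder for the SSE write
    }
    // rd.stop() (also invoked by the destructor) cancels any tasks that are still running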
@@ -684,7 +684,7 @@ struct server_task_result {
     }
     virtual bool is_stop() {
         // only used by server_task_result_cmpl_*
-        return false;
+        return true;
     }
     virtual int get_index() {
         return -1;
@@ -3238,105 +3238,6 @@ struct server_context {
         queue_results.send(std::move(res));
     }

-    //
-    // Functions to create new task(s) and receive result(s)
-    //
-
-    void cancel_tasks(const std::unordered_set<int> & id_tasks) {
-        std::vector<server_task> cancel_tasks;
-        cancel_tasks.reserve(id_tasks.size());
-        for (const auto & id_task : id_tasks) {
-            SRV_WRN("cancel task, id_task = %d\n", id_task);
-
-            server_task task(SERVER_TASK_TYPE_CANCEL);
-            task.id_target = id_task;
-            queue_results.remove_waiting_task_id(id_task);
-            cancel_tasks.push_back(std::move(task));
-        }
-        // push to beginning of the queue, so it has highest priority
-        queue_tasks.post(std::move(cancel_tasks), true);
-    }
-
-    // receive the results from task(s)
-    void receive_multi_results(
-            const std::unordered_set<int> & id_tasks,
-            const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
-            const std::function<void(json)> & error_handler,
-            const std::function<bool()> & is_connection_closed) {
-        std::vector<server_task_result_ptr> results(id_tasks.size());
-        for (int i = 0; i < (int)id_tasks.size(); i++) {
-            server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
-
-            if (is_connection_closed()) {
-                cancel_tasks(id_tasks);
-                return;
-            }
-
-            if (result == nullptr) {
-                i--; // retry
-                continue;
-            }
-
-            if (result->is_error()) {
-                error_handler(result->to_json());
-                cancel_tasks(id_tasks);
-                return;
-            }
-
-            GGML_ASSERT(
-                dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
-                || dynamic_cast<server_task_result_embd*>(result.get()) != nullptr
-                || dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr
-            );
-            const size_t idx = result->get_index();
-            GGML_ASSERT(idx < results.size() && "index out of range");
-            results[idx] = std::move(result);
-        }
-        result_handler(results);
-    }
-
-    // receive the results from task(s), in stream mode
-    void receive_cmpl_results_stream(
-            const std::unordered_set<int> & id_tasks,
-            const std::function<bool(server_task_result_ptr&)> & result_handler,
-            const std::function<void(json)> & error_handler,
-            const std::function<bool()> & is_connection_closed) {
-        size_t n_finished = 0;
-        while (true) {
-            server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
-
-            if (is_connection_closed()) {
-                cancel_tasks(id_tasks);
-                return;
-            }
-
-            if (result == nullptr) {
-                continue; // retry
-            }
-
-            if (result->is_error()) {
-                error_handler(result->to_json());
-                cancel_tasks(id_tasks);
-                return;
-            }
-
-            GGML_ASSERT(
-                dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
-                || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
-            );
-            if (!result_handler(result)) {
-                cancel_tasks(id_tasks);
-                break;
-            }
-
-            if (result->is_stop()) {
-                if (++n_finished == id_tasks.size()) {
-                    break;
-                }
-            }
-        }
-    }
-
     //
     // Functions to process the task
     //
@@ -4418,6 +4319,104 @@ struct server_context {
     }
 };

+// generator-like API for server responses, support pooling connection state and aggregating results
+struct server_response_reader {
+    std::unordered_set<int> id_tasks;
+    server_context & ctx_server;
+    size_t received_count = 0;
+    bool cancelled = false;
+
+    server_response_reader(server_context & ctx_server) : ctx_server(ctx_server) {}
+    ~server_response_reader() {
+        stop();
+    }
+
+    void post_tasks(std::vector<server_task> && tasks) {
+        id_tasks = server_task::get_list_id(tasks);
+        ctx_server.queue_results.add_waiting_tasks(tasks);
+        ctx_server.queue_tasks.post(std::move(tasks));
+    }
+
+    bool has_next() {
+        return !cancelled && received_count < id_tasks.size();
+    }
+
+    // return nullptr if should_stop() is true before receiving a result
+    // note: if one error is received, it will stop further processing and return error result
+    server_task_result_ptr next(const std::function<bool()> & should_stop) {
+        while (true) {
+            server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
+            if (result == nullptr) {
+                // timeout, check stop condition
+                if (should_stop()) {
+                    SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
+                    return nullptr;
+                }
+            } else {
+                if (result->is_error()) {
+                    stop(); // cancel remaining tasks
+                    SRV_DBG("%s", "received error result, stopping further processing\n");
+                    return result;
+                }
+                if (result->is_stop()) {
+                    received_count++;
+                }
+                return result;
+            }
+        }
+
+        // should not reach here
+    }
+
+    struct batch_response {
+        bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
+        std::vector<server_task_result_ptr> results;
+        server_task_result_ptr error; // nullptr if no error
+    };
+
+    batch_response wait_for_all(const std::function<bool()> & should_stop) {
+        batch_response batch_res;
+        batch_res.results.resize(id_tasks.size());
+        while (has_next()) {
+            auto res = next(should_stop);
+            if (res == nullptr) {
+                batch_res.is_terminated = true;
+                return batch_res;
+            }
+            if (res->is_error()) {
+                batch_res.error = std::move(res);
+                return batch_res;
+            }
+            const size_t idx = res->get_index();
+            GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
+            GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
+            batch_res.results[idx] = std::move(res);
+        }
+        return batch_res;
+    }
+
+    void stop() {
+        ctx_server.queue_results.remove_waiting_task_ids(id_tasks);
+        if (has_next() && !cancelled) {
+            // if tasks is not finished yet, cancel them
+            cancelled = true;
+            std::vector<server_task> cancel_tasks;
+            cancel_tasks.reserve(id_tasks.size());
+            for (const auto & id_task : id_tasks) {
+                SRV_WRN("cancel task, id_task = %d\n", id_task);
+                server_task task(SERVER_TASK_TYPE_CANCEL);
+                task.id_target = id_task;
+                ctx_server.queue_results.remove_waiting_task_id(id_task);
+                cancel_tasks.push_back(std::move(task));
+            }
+            // push to beginning of the queue, so it has highest priority
+            ctx_server.queue_tasks.post(std::move(cancel_tasks), true);
+        } else {
+            SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
+        }
+    }
+};
+
 static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
     // skip GH copilot requests when using default port
     if (req.path == "/v1/health") {
@@ -5000,7 +4999,10 @@ int main(int argc, char ** argv) {
     GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);

     auto completion_id = gen_chatcmplid();
-    std::unordered_set<int> task_ids;
+    // need to store the reader as a pointer, so that it won't be destroyed when the handle returns
+    // use shared_ptr as it's shared between the chunked_content_provider() and on_complete()
+    const auto rd = std::make_shared<server_response_reader>(ctx_server);
+
     try {
         std::vector<server_task> tasks;

@@ -5018,17 +5020,8 @@ int main(int argc, char ** argv) {
             // Everything else, including multimodal completions.
             inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
         }
-        const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
         tasks.reserve(inputs.size());
         for (size_t i = 0; i < inputs.size(); i++) {
-            auto n_prompt_tokens = inputs[i].size();
-            if (n_prompt_tokens >= n_ctx_slot) {
-                json error_data = format_error_response("the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
-                error_data["n_prompt_tokens"] = n_prompt_tokens;
-                error_data["n_ctx"] = n_ctx_slot;
-                res_error(res, error_data);
-                return;
-            }
             server_task task = server_task(type);

             task.id = ctx_server.queue_tasks.get_new_id();
@@ -5049,9 +5042,7 @@ int main(int argc, char ** argv) {
             tasks.push_back(std::move(task));
         }

-        task_ids = server_task::get_list_id(tasks);
-        ctx_server.queue_results.add_waiting_tasks(tasks);
-        ctx_server.queue_tasks.post(std::move(tasks));
+        rd->post_tasks(std::move(tasks));
     } catch (const std::exception & e) {
         res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
         return;
@@ -5060,54 +5051,95 @@ int main(int argc, char ** argv) {
     bool stream = json_value(data, "stream", false);

     if (!stream) {
-        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-            if (results.size() == 1) {
-                // single result
-                res_ok(res, results[0]->to_json());
+        // non-stream, wait for the results
+        auto all_results = rd->wait_for_all(is_connection_closed);
+        if (all_results.is_terminated) {
+            return; // connection is closed
+        } else if (all_results.error) {
+            res_error(res, all_results.error->to_json());
+            return;
         } else {
-            // multiple results (multitask)
             json arr = json::array();
-            for (auto & res : results) {
+            for (auto & res : all_results.results) {
+                GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
                 arr.push_back(res->to_json());
             }
-            res_ok(res, arr);
+            // if single request, return single object instead of array
+            res_ok(res, arr.size() == 1 ? arr[0] : arr);
         }
-        }, [&](const json & error_data) {
-            res_error(res, error_data);
-        }, is_connection_closed);
-
-        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
     } else {
-        const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
-            ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
+        // in streaming mode, the first error must be treated as non-stream response
+        // this is to match the OAI API behavior
+        // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
+        server_task_result_ptr first_result = rd->next(is_connection_closed);
+        if (first_result == nullptr) {
+            return; // connection is closed
+        } else if (first_result->is_error()) {
+            res_error(res, first_result->to_json());
+            return;
+        } else {
+            GGML_ASSERT(
+                dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
+                || dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
+            );
+        }
+
+        // next responses are streamed
+        json first_result_json = first_result->to_json();
+        const auto chunked_content_provider = [first_result_json, rd, oaicompat](size_t, httplib::DataSink & sink) mutable -> bool {
+            // flush the first result as it's not an error
+            if (!first_result_json.empty()) {
+                if (!server_sent_event(sink, first_result_json)) {
+                    sink.done();
+                    return false; // sending failed, go to on_complete()
+                }
+                first_result_json.clear(); // mark as sent
+            }
+
+            // receive subsequent results
+            auto result = rd->next([&sink]{ return !sink.is_writable(); });
+            if (result == nullptr) {
+                sink.done();
+                return false; // connection is closed, go to on_complete()
+            }
+
+            // send the results
             json res_json = result->to_json();
-            if (res_json.is_array()) {
-                for (const auto & res : res_json) {
-                    if (!server_sent_event(sink, res)) {
-                        // sending failed (HTTP connection closed), cancel the generation
-                        return false;
-                    }
-                }
-                return true;
+            bool ok = false;
+            if (result->is_error()) {
+                ok = server_sent_event(sink, json {{ "error", result->to_json() }});
+                sink.done();
+                return false; // go to on_complete()
             } else {
-                return server_sent_event(sink, res_json);
+                GGML_ASSERT(
+                    dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+                    || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+                );
+                ok = server_sent_event(sink, res_json);
             }
-            }, [&](const json & error_data) {
-                server_sent_event(sink, json{{"error", error_data}});
-            }, [&sink]() {
-                // note: do not use req.is_connection_closed here because req is already destroyed
-                return !sink.is_writable();
-            });
+
+            if (!ok) {
+                sink.done();
+                return false; // sending failed, go to on_complete()
+            }
+
+            // check if there is more data
+            if (!rd->has_next()) {
                 if (oaicompat != OAICOMPAT_TYPE_NONE) {
                     static const std::string ev_done = "data: [DONE]\n\n";
                     sink.write(ev_done.data(), ev_done.size());
                 }
                 sink.done();
-                return false;
+                return false; // no more data, go to on_complete()
+            }
+
+            // has next data, continue
+            return true;
         };

-        auto on_complete = [task_ids, &ctx_server] (bool) {
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+        auto on_complete = [rd](bool) {
+            rd->stop();
         };

         res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
@@ -5401,8 +5433,7 @@ int main(int argc, char ** argv) {

     // create and queue the task
     json responses = json::array();
-    bool error = false;
-    std::unordered_set<int> task_ids;
+    server_response_reader rd(ctx_server);
     {
         std::vector<server_task> tasks;
         for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -5418,27 +5449,23 @@ int main(int argc, char ** argv) {

             tasks.push_back(std::move(task));
         }

-        task_ids = server_task::get_list_id(tasks);
-        ctx_server.queue_results.add_waiting_tasks(tasks);
-        ctx_server.queue_tasks.post(std::move(tasks));
+        rd.post_tasks(std::move(tasks));
     }

-    // get the result
-    ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-        for (auto & res : results) {
+    // wait for the results
+    auto all_results = rd.wait_for_all(req.is_connection_closed);
+
+    // collect results
+    if (all_results.is_terminated) {
+        return; // connection is closed
+    } else if (all_results.error) {
+        res_error(res, all_results.error->to_json());
+        return;
+    } else {
+        for (auto & res : all_results.results) {
             GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
             responses.push_back(res->to_json());
         }
-    }, [&](const json & error_data) {
-        res_error(res, error_data);
-        error = true;
-    }, req.is_connection_closed);
-
-    ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-
-    if (error) {
-        return;
     }

     // write JSON response
@@ -5492,8 +5519,7 @@ int main(int argc, char ** argv) {

     // create and queue the task
     json responses = json::array();
-    bool error = false;
-    std::unordered_set<int> task_ids;
+    server_response_reader rd(ctx_server);
     {
         std::vector<server_task> tasks;
         tasks.reserve(documents.size());
@@ -5505,24 +5531,23 @@ int main(int argc, char ** argv) {
             task.tokens = std::move(tmp);
             tasks.push_back(std::move(task));
         }
-
-        task_ids = server_task::get_list_id(tasks);
-        ctx_server.queue_results.add_waiting_tasks(tasks);
-        ctx_server.queue_tasks.post(std::move(tasks));
+        rd.post_tasks(std::move(tasks));
     }

-    ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-        for (auto & res : results) {
+    // wait for the results
+    auto all_results = rd.wait_for_all(req.is_connection_closed);
+
+    // collect results
+    if (all_results.is_terminated) {
+        return; // connection is closed
+    } else if (all_results.error) {
+        res_error(res, all_results.error->to_json());
+        return;
+    } else {
+        for (auto & res : all_results.results) {
             GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
             responses.push_back(res->to_json());
         }
-    }, [&](const json & error_data) {
-        res_error(res, error_data);
-        error = true;
-    }, req.is_connection_closed);
-
-    if (error) {
-        return;
     }

     // write JSON response
@@ -453,15 +453,29 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
     return out;
 }

+// note: if data is a json array, it will be sent as multiple events, one per item
 static bool server_sent_event(httplib::DataSink & sink, const json & data) {
+    static auto send_single = [](httplib::DataSink & sink, const json & data) -> bool {
         const std::string str =
             "data: " +
             data.dump(-1, ' ', false, json::error_handler_t::replace) +
             "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).

         LOG_DBG("data stream, to_send: %s", str.c_str());

         return sink.write(str.c_str(), str.size());
+    };
+
+    if (data.is_array()) {
+        for (const auto & item : data) {
+            if (!send_single(sink, item)) {
+                return false;
+            }
+        }
+    } else {
+        return send_single(sink, data);
+    }
+
+    return true;
 }

 //
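As a side note on the utils.hpp hunk above: server_sent_event() now splits a JSON array into one SSE message per element, so a streamed array such as [{"index":0},{"index":1}] is framed roughly like this on the wire (payloads are illustrative):

    data: {"index":0}

    data: {"index":1}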