diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 0cf2f767b3..0b3c77879c 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -684,7 +684,7 @@ struct server_task_result {
     }
     virtual bool is_stop() {
         // only used by server_task_result_cmpl_*
-        return false;
+        return true;
    }
     virtual int get_index() {
         return -1;
@@ -3238,105 +3238,6 @@ struct server_context {
         queue_results.send(std::move(res));
     }
 
-    //
-    // Functions to create new task(s) and receive result(s)
-    //
-
-    void cancel_tasks(const std::unordered_set<int> & id_tasks) {
-        std::vector<server_task> cancel_tasks;
-        cancel_tasks.reserve(id_tasks.size());
-        for (const auto & id_task : id_tasks) {
-            SRV_WRN("cancel task, id_task = %d\n", id_task);
-
-            server_task task(SERVER_TASK_TYPE_CANCEL);
-            task.id_target = id_task;
-            queue_results.remove_waiting_task_id(id_task);
-            cancel_tasks.push_back(std::move(task));
-        }
-        // push to the beginning of the queue, so it has the highest priority
-        queue_tasks.post(std::move(cancel_tasks), true);
-    }
-
-    // receive the results from task(s)
-    void receive_multi_results(
-            const std::unordered_set<int> & id_tasks,
-            const std::function<void(std::vector<server_task_result_ptr> &)> & result_handler,
-            const std::function<void(json)> & error_handler,
-            const std::function<bool()> & is_connection_closed) {
-        std::vector<server_task_result_ptr> results(id_tasks.size());
-        for (int i = 0; i < (int)id_tasks.size(); i++) {
-            server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
-
-            if (is_connection_closed()) {
-                cancel_tasks(id_tasks);
-                return;
-            }
-
-            if (result == nullptr) {
-                i--; // retry
-                continue;
-            }
-
-            if (result->is_error()) {
-                error_handler(result->to_json());
-                cancel_tasks(id_tasks);
-                return;
-            }
-
-            GGML_ASSERT(
-                dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
-                || dynamic_cast<server_task_result_embd*>(result.get()) != nullptr
-                || dynamic_cast<server_task_result_rerank*>(result.get()) != nullptr
-            );
-            const size_t idx = result->get_index();
-            GGML_ASSERT(idx < results.size() && "index out of range");
-            results[idx] = std::move(result);
-        }
-        result_handler(results);
-    }
-
-    // receive the results from task(s), in stream mode
-    void receive_cmpl_results_stream(
-            const std::unordered_set<int> & id_tasks,
-            const std::function<bool(server_task_result_ptr &)> & result_handler,
-            const std::function<void(json)> & error_handler,
-            const std::function<bool()> & is_connection_closed) {
-        size_t n_finished = 0;
-        while (true) {
-            server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
-
-            if (is_connection_closed()) {
-                cancel_tasks(id_tasks);
-                return;
-            }
-
-            if (result == nullptr) {
-                continue; // retry
-            }
-
-            if (result->is_error()) {
-                error_handler(result->to_json());
-                cancel_tasks(id_tasks);
-                return;
-            }
-
-            GGML_ASSERT(
-                dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
-                || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
-            );
-            if (!result_handler(result)) {
-                cancel_tasks(id_tasks);
-                break;
-            }
-
-            if (result->is_stop()) {
-                if (++n_finished == id_tasks.size()) {
-                    break;
-                }
-            }
-        }
-    }
-
     //
     // Functions to process the task
     //
@@ -4418,6 +4319,104 @@ struct server_context {
     }
 };
 
+// generator-like API for server responses, supports polling the connection state and aggregating results
+struct server_response_reader {
+    std::unordered_set<int> id_tasks;
+    server_context & ctx_server;
+    size_t received_count = 0;
+    bool cancelled = false;
+
+    server_response_reader(server_context & ctx_server) : ctx_server(ctx_server) {}
+    ~server_response_reader() {
+        stop();
+    }
+
+    void post_tasks(std::vector<server_task> && tasks) {
+        id_tasks = server_task::get_list_id(tasks);
+        ctx_server.queue_results.add_waiting_tasks(tasks);
+        ctx_server.queue_tasks.post(std::move(tasks));
+    }
+
+    bool has_next() {
+        return !cancelled && received_count < id_tasks.size();
+    }
+
+    // returns nullptr if should_stop() becomes true before a result is received
+    // note: if an error result is received, it stops further processing and is returned to the caller
+    server_task_result_ptr next(const std::function<bool()> & should_stop) {
+        while (true) {
+            server_task_result_ptr result = ctx_server.queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
+            if (result == nullptr) {
+                // timeout, check the stop condition
+                if (should_stop()) {
+                    SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
+                    return nullptr;
+                }
+            } else {
+                if (result->is_error()) {
+                    stop(); // cancel the remaining tasks
+                    SRV_DBG("%s", "received error result, stopping further processing\n");
+                    return result;
+                }
+                if (result->is_stop()) {
+                    received_count++;
+                }
+                return result;
+            }
+        }
+
+        // should not reach here
+    }
+
+    struct batch_response {
+        bool is_terminated = false; // if true, processing was stopped before all results were received
+        std::vector<server_task_result_ptr> results;
+        server_task_result_ptr error; // nullptr if no error
+    };
+
+    batch_response wait_for_all(const std::function<bool()> & should_stop) {
+        batch_response batch_res;
+        batch_res.results.resize(id_tasks.size());
+        while (has_next()) {
+            auto res = next(should_stop);
+            if (res == nullptr) {
+                batch_res.is_terminated = true;
+                return batch_res;
+            }
+            if (res->is_error()) {
+                batch_res.error = std::move(res);
+                return batch_res;
+            }
+            const size_t idx = res->get_index();
+            GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
+            GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
+            batch_res.results[idx] = std::move(res);
+        }
+        return batch_res;
+    }
+
+    void stop() {
+        ctx_server.queue_results.remove_waiting_task_ids(id_tasks);
+        if (has_next() && !cancelled) {
+            // if the tasks are not finished yet, cancel them
+            cancelled = true;
+            std::vector<server_task> cancel_tasks;
+            cancel_tasks.reserve(id_tasks.size());
+            for (const auto & id_task : id_tasks) {
+                SRV_WRN("cancel task, id_task = %d\n", id_task);
+                server_task task(SERVER_TASK_TYPE_CANCEL);
+                task.id_target = id_task;
+                ctx_server.queue_results.remove_waiting_task_id(id_task);
+                cancel_tasks.push_back(std::move(task));
+            }
+            // push to the beginning of the queue, so it has the highest priority
+            ctx_server.queue_tasks.post(std::move(cancel_tasks), true);
+        } else {
+            SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
+        }
+    }
+};
+
 static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
     // skip GH copilot requests when using default port
     if (req.path == "/v1/health") {
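The struct above replaces the removed receive_multi_results()/receive_cmpl_results_stream() helpers with a pull-style API. A minimal consumption sketch, not compilable on its own: `tasks`, `should_stop`, and the two `handle_*` callbacks are hypothetical stand-ins for whatever the calling handler does.

    // sketch: generator-style consumption, one result at a time
    server_response_reader rd(ctx_server);
    rd.post_tasks(std::move(tasks));        // registers the waiting ids, then enqueues the tasks

    while (rd.has_next()) {
        server_task_result_ptr r = rd.next(should_stop);
        if (r == nullptr) {
            break;                          // should_stop() fired during a poll timeout
        }
        if (r->is_error()) {
            handle_error(r->to_json());     // next() has already cancelled the remaining tasks
            break;
        }
        handle_result(std::move(r));        // only is_stop() results count toward completion
    }
    // leaving scope runs ~server_response_reader() -> stop(), which cancels any leftovers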
@@ -5000,7 +4999,10 @@ int main(int argc, char ** argv) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
 
         auto completion_id = gen_chatcmplid();
-        std::unordered_set<int> task_ids;
+        // need to store the reader as a pointer, so that it won't be destroyed when the handler returns
+        // use shared_ptr as it's shared between the chunked_content_provider() and on_complete()
+        const auto rd = std::make_shared<server_response_reader>(ctx_server);
+
         try {
             std::vector<server_task> tasks;
 
@@ -5018,17 +5020,8 @@ int main(int argc, char ** argv) {
                 // Everything else, including multimodal completions.
                 inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true);
             }
-            const size_t n_ctx_slot = ctx_server.slots.front().n_ctx;
             tasks.reserve(inputs.size());
             for (size_t i = 0; i < inputs.size(); i++) {
-                auto n_prompt_tokens = inputs[i].size();
-                if (n_prompt_tokens >= n_ctx_slot) {
-                    json error_data = format_error_response("the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE);
-                    error_data["n_prompt_tokens"] = n_prompt_tokens;
-                    error_data["n_ctx"] = n_ctx_slot;
-                    res_error(res, error_data);
-                    return;
-                }
                 server_task task = server_task(type);
 
                 task.id = ctx_server.queue_tasks.get_new_id();
@@ -5049,9 +5042,7 @@ int main(int argc, char ** argv) {
                 tasks.push_back(std::move(task));
             }
 
-            task_ids = server_task::get_list_id(tasks);
-            ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(std::move(tasks));
+            rd->post_tasks(std::move(tasks));
         } catch (const std::exception & e) {
             res_error(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST));
             return;
@@ -5060,54 +5051,95 @@ int main(int argc, char ** argv) {
 
         bool stream = json_value(data, "stream", false);
 
         if (!stream) {
-            ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-                if (results.size() == 1) {
-                    // single result
-                    res_ok(res, results[0]->to_json());
-                } else {
-                    // multiple results (multitask)
-                    json arr = json::array();
-                    for (auto & res : results) {
-                        arr.push_back(res->to_json());
-                    }
-                    res_ok(res, arr);
+            // non-stream, wait for the results
+            auto all_results = rd->wait_for_all(is_connection_closed);
+            if (all_results.is_terminated) {
+                return; // connection is closed
+            } else if (all_results.error) {
+                res_error(res, all_results.error->to_json());
+                return;
+            } else {
+                json arr = json::array();
+                for (auto & res : all_results.results) {
+                    GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr);
+                    arr.push_back(res->to_json());
                 }
-            }, [&](const json & error_data) {
-                res_error(res, error_data);
-            }, is_connection_closed);
+                // if single request, return a single object instead of an array
+                res_ok(res, arr.size() == 1 ? arr[0] : arr);
+            }
 
-            ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
-            const auto chunked_content_provider = [task_ids, &ctx_server, oaicompat](size_t, httplib::DataSink & sink) {
-                ctx_server.receive_cmpl_results_stream(task_ids, [&](server_task_result_ptr & result) -> bool {
-                    json res_json = result->to_json();
-                    if (res_json.is_array()) {
-                        for (const auto & res : res_json) {
-                            if (!server_sent_event(sink, res)) {
-                                // sending failed (HTTP connection closed), cancel the generation
-                                return false;
-                            }
-                        }
-                        return true;
-                    } else {
-                        return server_sent_event(sink, res_json);
+            // in streaming mode, the first error must be treated as a non-stream response
+            // this is to match the OAI API behavior
+            // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309
+            server_task_result_ptr first_result = rd->next(is_connection_closed);
+            if (first_result == nullptr) {
+                return; // connection is closed
+            } else if (first_result->is_error()) {
+                res_error(res, first_result->to_json());
+                return;
+            } else {
+                GGML_ASSERT(
+                    dynamic_cast<server_task_result_cmpl_partial*>(first_result.get()) != nullptr
+                    || dynamic_cast<server_task_result_cmpl_final*>(first_result.get()) != nullptr
+                );
+            }
+
+            // subsequent responses are streamed
+            json first_result_json = first_result->to_json();
+            const auto chunked_content_provider = [first_result_json, rd, oaicompat](size_t, httplib::DataSink & sink) mutable -> bool {
+                // flush the first result, as it's not an error
+                if (!first_result_json.empty()) {
+                    if (!server_sent_event(sink, first_result_json)) {
+                        sink.done();
+                        return false; // sending failed, go to on_complete()
                     }
-                }, [&](const json & error_data) {
-                    server_sent_event(sink, json{{"error", error_data}});
-                }, [&sink]() {
-                    // note: do not use req.is_connection_closed here because req is already destroyed
-                    return !sink.is_writable();
-                });
-                if (oaicompat != OAICOMPAT_TYPE_NONE) {
-                    static const std::string ev_done = "data: [DONE]\n\n";
-                    sink.write(ev_done.data(), ev_done.size());
+                    first_result_json.clear(); // mark as sent
                 }
-                sink.done();
-                return false;
+
+                // receive subsequent results
+                auto result = rd->next([&sink]{ return !sink.is_writable(); });
+                if (result == nullptr) {
+                    sink.done();
+                    return false; // connection is closed, go to on_complete()
+                }
+
+                // send the results
+                json res_json = result->to_json();
+                bool ok = false;
+                if (result->is_error()) {
+                    ok = server_sent_event(sink, json {{ "error", result->to_json() }});
+                    sink.done();
+                    return false; // go to on_complete()
+                } else {
+                    GGML_ASSERT(
+                        dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
+                        || dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
+                    );
+                    ok = server_sent_event(sink, res_json);
+                }
+
+                if (!ok) {
+                    sink.done();
+                    return false; // sending failed, go to on_complete()
+                }
+
+                // check if there is more data
+                if (!rd->has_next()) {
+                    if (oaicompat != OAICOMPAT_TYPE_NONE) {
+                        static const std::string ev_done = "data: [DONE]\n\n";
+                        sink.write(ev_done.data(), ev_done.size());
+                    }
+                    sink.done();
+                    return false; // no more data, go to on_complete()
+                }
+
+                // has next data, continue
+                return true;
             };
-            auto on_complete = [task_ids, &ctx_server] (bool) {
-                ctx_server.queue_results.remove_waiting_task_ids(task_ids);
+            auto on_complete = [rd](bool) {
+                rd->stop();
             };
 
             res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
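Behavior note for the hunk above: an error on the very first result is now returned as a regular HTTP error response (matching OAI semantics), while an error after streaming has begun is delivered in-band as a final SSE event. Illustrative wire traffic under that contract (payload bodies abbreviated, not literal server output):

    data: {"choices":[{"index":0,"delta":{"content":"..."}}]}    <- first_result_json, flushed on the first provider call

    data: {"choices":[{"index":0,"delta":{"content":"..."}}]}    <- one event per subsequent rd->next() result

    data: [DONE]                                                 <- appended only when oaicompat != OAICOMPAT_TYPE_NONE

    data: {"error":{...}}                                        <- a mid-stream error becomes the terminal event instead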
@@ -5401,8 +5433,7 @@ int main(int argc, char ** argv) {
 
         // create and queue the task
         json responses = json::array();
-        bool error = false;
-        std::unordered_set<int> task_ids;
+        server_response_reader rd(ctx_server);
         {
             std::vector<server_task> tasks;
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
@@ -5418,27 +5449,23 @@ int main(int argc, char ** argv) {
 
                 tasks.push_back(std::move(task));
             }
-
-            task_ids = server_task::get_list_id(tasks);
-            ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(std::move(tasks));
+            rd.post_tasks(std::move(tasks));
         }
 
-        // get the result
-        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-            for (auto & res : results) {
+        // wait for the results
+        auto all_results = rd.wait_for_all(req.is_connection_closed);
+
+        // collect results
+        if (all_results.is_terminated) {
+            return; // connection is closed
+        } else if (all_results.error) {
+            res_error(res, all_results.error->to_json());
+            return;
+        } else {
+            for (auto & res : all_results.results) {
                 GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
                 responses.push_back(res->to_json());
             }
-        }, [&](const json & error_data) {
-            res_error(res, error_data);
-            error = true;
-        }, req.is_connection_closed);
-
-        ctx_server.queue_results.remove_waiting_task_ids(task_ids);
-
-        if (error) {
-            return;
-        }
+        }
 
         // write JSON response
@@ -5492,8 +5519,7 @@ int main(int argc, char ** argv) {
 
         // create and queue the task
        json responses = json::array();
-        bool error = false;
-        std::unordered_set<int> task_ids;
+        server_response_reader rd(ctx_server);
         {
             std::vector<server_task> tasks;
             tasks.reserve(documents.size());
@@ -5505,24 +5531,23 @@ int main(int argc, char ** argv) {
                 task.tokens = std::move(tmp);
                 tasks.push_back(std::move(task));
             }
-
-            task_ids = server_task::get_list_id(tasks);
-            ctx_server.queue_results.add_waiting_tasks(tasks);
-            ctx_server.queue_tasks.post(std::move(tasks));
+            rd.post_tasks(std::move(tasks));
         }
 
-        ctx_server.receive_multi_results(task_ids, [&](std::vector<server_task_result_ptr> & results) {
-            for (auto & res : results) {
+        // wait for the results
+        auto all_results = rd.wait_for_all(req.is_connection_closed);
+
+        // collect results
+        if (all_results.is_terminated) {
+            return; // connection is closed
+        } else if (all_results.error) {
+            res_error(res, all_results.error->to_json());
+            return;
+        } else {
+            for (auto & res : all_results.results) {
                 GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
                 responses.push_back(res->to_json());
             }
-        }, [&](const json & error_data) {
-            res_error(res, error_data);
-            error = true;
-        }, req.is_connection_closed);
-
-        if (error) {
-            return;
-        }
+        }
 
         // write JSON response
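Both handlers above lean on the same wait_for_all() contract: results are slotted by get_index(), so `responses` lines up with the input order even when slots finish out of order, and the asserts in wait_for_all() catch out-of-range or duplicate indices. A condensed sketch of the pattern (error and termination branches as in the hunks above; `tasks` assumed to have been built with `task.index = i`):

    server_response_reader rd(ctx_server);
    rd.post_tasks(std::move(tasks));                 // task i was created with index i

    auto all_results = rd.wait_for_all(req.is_connection_closed);
    if (!all_results.is_terminated && !all_results.error) {
        for (auto & res : all_results.results) {     // results[i] belongs to input i
            responses.push_back(res->to_json());
        }
    }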
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index e9d4431ddf..b1ecc5af5e 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -453,15 +453,29 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
     return out;
 }
 
+// note: if data is a json array, it will be sent as multiple events, one per item
 static bool server_sent_event(httplib::DataSink & sink, const json & data) {
-    const std::string str =
-        "data: " +
-        data.dump(-1, ' ', false, json::error_handler_t::replace) +
-        "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
+    static auto send_single = [](httplib::DataSink & sink, const json & data) -> bool {
+        const std::string str =
+            "data: " +
+            data.dump(-1, ' ', false, json::error_handler_t::replace) +
+            "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
 
-    LOG_DBG("data stream, to_send: %s", str.c_str());
+        LOG_DBG("data stream, to_send: %s", str.c_str());
+        return sink.write(str.c_str(), str.size());
+    };
 
-    return sink.write(str.c_str(), str.size());
+    if (data.is_array()) {
+        for (const auto & item : data) {
+            if (!send_single(sink, item)) {
+                return false;
+            }
+        }
+    } else {
+        return send_single(sink, data);
+    }
+
+    return true;
 }
 
 //
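The framing produced by send_single() is easy to reproduce standalone. A sketch using only nlohmann::json, with the httplib sink swapped for stdout, showing why an array argument now yields one SSE event per element:

    // sse_demo.cpp - illustrative only; mirrors server_sent_event() above
    #include <cstdio>
    #include <string>
    #include <nlohmann/json.hpp>

    using json = nlohmann::json;

    static void emit(const json & data) {
        if (data.is_array()) {
            // same as the new array branch: one event per item
            for (const auto & item : data) {
                emit(item);
            }
            return;
        }
        const std::string str =
            "data: " +
            data.dump(-1, ' ', false, json::error_handler_t::replace) +
            "\n\n"; // the blank line is what terminates an SSE message
        fwrite(str.data(), 1, str.size(), stdout);
    }

    int main() {
        emit(json{{"index", 0}});                                      // one event
        emit(json::array({ json{{"index", 0}}, json{{"index", 1}} })); // two events
        return 0;
    }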