diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt
index c801e84c3d..1fccfdd17f 100644
--- a/tools/server/CMakeLists.txt
+++ b/tools/server/CMakeLists.txt
@@ -14,6 +14,8 @@ endif()
 set(TARGET_SRCS
     server.cpp
     utils.hpp
+    server-http.cpp
+    server-http.h
 )
 set(PUBLIC_ASSETS
     index.html.gz
diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp
new file mode 100644
index 0000000000..9751f290c0
--- /dev/null
+++ b/tools/server/server-http.cpp
@@ -0,0 +1,386 @@
+#include "utils.hpp"
+#include "common.h"
+#include "server-http.h"
+
+#include <cpp-httplib/httplib.h>
+
+#include <memory>
+#include <string>
+#include <thread>
+
+// auto generated files (see README.md for details)
+#include "index.html.gz.hpp"
+#include "loading.html.hpp"
+
+//
+// HTTP implementation using cpp-httplib
+//
+
+class server_http_context::Impl {
+public:
+    std::unique_ptr<httplib::Server> srv;
+};
+
+server_http_context::server_http_context()
+    : pimpl(std::make_unique<Impl>())
+{}
+
+server_http_context::~server_http_context() = default;
+
+static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
+    // skip GH copilot requests when using default port
+    if (req.path == "/v1/health") {
+        return;
+    }
+
+    // reminder: this function is not covered by httplib's exception handler;
+    // if someone does more complicated stuff here, think about wrapping it in a try-catch
+
+    SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);
+
+    SRV_DBG("request: %s\n", req.body.c_str());
+    SRV_DBG("response: %s\n", res.body.c_str());
+}
+
+bool server_http_context::init(const common_params & params) {
+    path_prefix = params.api_prefix;
+    port        = params.port;
+    hostname    = params.hostname;
+
+#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
+    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
+        LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
+        pimpl->srv.reset(
+            new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str())
+        );
+    } else {
+        LOG_INF("Running without SSL\n");
+        pimpl->srv.reset(new httplib::Server());
+    }
+#else
+    if (params.ssl_file_key != "" && params.ssl_file_cert != "") {
+        LOG_ERR("Server is built without SSL support\n");
+        return false;
+    }
+    pimpl->srv.reset(new httplib::Server());
+#endif
+
+    auto & srv = pimpl->srv;
+    srv->set_default_headers({{"Server", "llama.cpp"}});
+    srv->set_logger(log_server_request);
+    srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
+        // this is a fail-safe; exceptions should already be handled by `ex_wrapper`
+
+        std::string message;
+        try {
+            std::rethrow_exception(ep);
+        } catch (const std::exception & e) {
+            message = e.what();
+        } catch (...) {
{ + message = "Unknown Exception"; + } + + res.status = 500; + res.set_content(message, "text/plain"); + LOG_ERR("got exception: %s\n", message.c_str()); + }); + + srv->set_error_handler([](const httplib::Request &, httplib::Response & res) { + if (res.status == 404) { + res.set_content( + safe_json_to_str(json { + {"error", { + {"message", "File Not Found"}, + {"type", "not_found_error"}, + {"code", 404} + }} + }), + "application/json; charset=utf-8" + ); + } + // for other error codes, we skip processing here because it's already done by res->error() + }); + + // set timeouts and change hostname and port + srv->set_read_timeout (params.timeout_read); + srv->set_write_timeout(params.timeout_write); + + if (params.api_keys.size() == 1) { + auto key = params.api_keys[0]; + std::string substr = key.substr(std::max((int)(key.length() - 4), 0)); + LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str()); + } else if (params.api_keys.size() > 1) { + LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size()); + } + + // + // Middlewares + // + + auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) { + static const std::unordered_set public_endpoints = { + "/health", + "/v1/health", + "/models", + "/v1/models", + "/api/tags" + }; + + // If API key is not set, skip validation + if (api_keys.empty()) { + return true; + } + + // If path is public or is static file, skip validation + if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { + return true; + } + + // Check for API key in the header + auto auth_header = req.get_header_value("Authorization"); + + std::string prefix = "Bearer "; + if (auth_header.substr(0, prefix.size()) == prefix) { + std::string received_api_key = auth_header.substr(prefix.size()); + if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) { + return true; // API key is valid + } + } + + // API key is invalid or not provided + res.status = 401; + res.set_content( + safe_json_to_str(json { + {"error", { + {"message", "Invalid API Key"}, + {"type", "authentication_error"}, + {"code", 401} + }} + }), + "application/json; charset=utf-8" + ); + + LOG_WRN("Unauthorized: Invalid API Key\n"); + + return false; + }; + + auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) { + bool ready = is_ready.load(); + if (!ready) { + auto tmp = string_split(req.path, '.'); + if (req.path == "/" || tmp.back() == "html") { + res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); + res.status = 503; + } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { + // allow the models endpoint to be accessed during loading + return true; + } else { + res.status = 503; + res.set_content( + safe_json_to_str(json { + {"error", { + {"message", "Loading model"}, + {"type", "unavailable_error"}, + {"code", 503} + }} + }), + "application/json; charset=utf-8" + ); + } + return false; + } + return true; + }; + + // register server middlewares + srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) { + res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); + // If this is OPTIONS request, skip validation because browsers don't include Authorization header + if (req.method == "OPTIONS") { + res.set_header("Access-Control-Allow-Credentials", "true"); + 
res.set_header("Access-Control-Allow-Methods", "GET, POST"); + res.set_header("Access-Control-Allow-Headers", "*"); + res.set_content("", "text/html"); // blank response, no data + return httplib::Server::HandlerResponse::Handled; // skip further processing + } + if (!middleware_server_state(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } + if (!middleware_validate_api_key(req, res)) { + return httplib::Server::HandlerResponse::Handled; + } + return httplib::Server::HandlerResponse::Unhandled; + }); + + int n_threads_http = params.n_threads_http; + if (n_threads_http < 1) { + // +2 threads for monitoring endpoints + n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); + } + LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http); + srv->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); }; + + // + // Web UI setup + // + + if (!params.webui) { + LOG_INF("Web UI is disabled\n"); + } else { + // register static assets routes + if (!params.public_path.empty()) { + // Set the base directory for serving static files + bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path); + if (!is_found) { + LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); + return 1; + } + } else { + // using embedded static index.html + srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) { + if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { + res.set_content("Error: gzip is not supported by this browser", "text/plain"); + } else { + res.set_header("Content-Encoding", "gzip"); + // COEP and COOP headers, required by pyodide (python interpreter) + res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); + res.set_header("Cross-Origin-Opener-Policy", "same-origin"); + res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); + } + return false; + }); + } + } + return true; +} + +bool server_http_context::start() { + // Bind and listen + + auto & srv = pimpl->srv; + bool was_bound = false; + bool is_sock = false; + if (string_ends_with(std::string(hostname), ".sock")) { + is_sock = true; + LOG_INF("%s: setting address family to AF_UNIX\n", __func__); + srv->set_address_family(AF_UNIX); + // bind_to_port requires a second arg, any value other than 0 should + // simply get ignored + was_bound = srv->bind_to_port(hostname, 8080); + } else { + LOG_INF("%s: binding port with default address family\n", __func__); + // bind HTTP listen port + if (port == 0) { + int bound_port = srv->bind_to_any_port(hostname); + was_bound = (bound_port >= 0); + if (was_bound) { + port = bound_port; + } + } else { + was_bound = srv->bind_to_port(hostname, port); + } + } + + if (!was_bound) { + LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port); + return false; + } + + // run the HTTP server in a thread + thread = std::thread([this]() { pimpl->srv->listen_after_bind(); }); + srv->wait_until_ready(); + + listening_address = is_sock ? 
string_format("unix://%s", hostname.c_str()) + : string_format("http://%s:%d", hostname.c_str(), port); + return true; +} + +void server_http_context::stop() const { + if (pimpl->srv) { + pimpl->srv->stop(); + } +} + +static void set_headers(httplib::Response & res, const std::map & headers) { + for (const auto & [key, value] : headers) { + res.set_header(key, value); + } +} + +static std::map get_params(const httplib::Request & req) { + std::map params; + for (const auto & [key, value] : req.params) { + params[key] = value; + } + for (const auto & [key, value] : req.path_params) { + params[key] = value; + } + return params; +} + +static std::map get_headers(const httplib::Request & req) { + std::map headers; + for (const auto & [key, value] : req.headers) { + headers[key] = value; + } + return headers; +} + +static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) { + if (response->is_stream()) { + res.status = response->status; + set_headers(res, response->headers); + std::string content_type = response->content_type; + // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it + std::shared_ptr r_ptr = std::move(response); + const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool { + std::string chunk; + bool has_next = response->next(chunk); + if (!chunk.empty()) { + // TODO: maybe handle sink.write unsuccessful? for now, we rely on is_connection_closed() + sink.write(chunk.data(), chunk.size()); + SRV_DBG("http: streamed chunk: %s\n", chunk.c_str()); + } + if (!has_next) { + sink.done(); + SRV_DBG("%s", "http: stream ended\n"); + } + return has_next; + }; + const auto on_complete = [response = r_ptr](bool) mutable { + response.reset(); // trigger the destruction of the response object + }; + res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete); + } else { + res.status = response->status; + set_headers(res, response->headers); + res.set_content(response->data, response->content_type); + } +} + +void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const { + pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { + server_http_res_ptr response = handler(server_http_req{ + get_params(req), + get_headers(req), + req.path, + req.body, + req.is_connection_closed + }); + process_handler_response(response, res); + }); +} + +void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const { + pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { + server_http_res_ptr response = handler(server_http_req{ + get_params(req), + get_headers(req), + req.path, + req.body, + req.is_connection_closed + }); + process_handler_response(response, res); + }); +} + diff --git a/tools/server/server-http.h b/tools/server/server-http.h new file mode 100644 index 0000000000..24c0b40117 --- /dev/null +++ b/tools/server/server-http.h @@ -0,0 +1,78 @@ +#pragma once + +#include +#include +#include +#include +#include + +struct common_params; + +// generator-like API for HTTP response generation +// this object response with one of the 2 modes: +// 1) normal response: `data` contains the full response body +// 2) streaming response: each call to next(output) generates the next chunk +// when next(output) returns false, no more data after the current chunk +// note: some chunks can be 
+struct server_http_res {
+    std::string content_type = "application/json; charset=utf-8";
+    int status = 200;
+    std::string data;
+    std::map<std::string, std::string> headers;
+
+    // TODO: move this to a virtual function once we have proper polymorphism support
+    std::function<bool(std::string &)> next = nullptr;
+    bool is_stream() const {
+        return next != nullptr;
+    }
+
+    virtual ~server_http_res() = default;
+};
+
+// unique pointer, used by set_chunked_content_provider
+// httplib requires the stream provider to be stored on the heap
+using server_http_res_ptr = std::unique_ptr<server_http_res>;
+
+struct server_http_req {
+    std::map<std::string, std::string> params;  // path_params + query_params
+    std::map<std::string, std::string> headers; // reserved for future use
+    std::string path;                           // reserved for future use
+    std::string body;
+    const std::function<bool()> & should_stop;
+
+    std::string get_param(const std::string & key, const std::string & def = "") const {
+        auto it = params.find(key);
+        if (it != params.end()) {
+            return it->second;
+        }
+        return def;
+    }
+};
+
+struct server_http_context {
+    class Impl;
+    std::unique_ptr<Impl> pimpl;
+
+    std::thread thread; // server thread
+    std::atomic<bool> is_ready = false;
+
+    std::string path_prefix;
+    std::string hostname;
+    int port;
+
+    server_http_context();
+    ~server_http_context();
+
+    bool init(const common_params & params);
+    bool start();
+    void stop() const;
+
+    // note: the handler should never throw exceptions
+    using handler_t = std::function<server_http_res_ptr(const server_http_req &)>;
+
+    void get(const std::string & path, const handler_t & handler) const;
+    void post(const std::string & path, const handler_t & handler) const;
+
+    // for debugging
+    std::string listening_address;
+};
diff --git a/tools/server/server.cpp b/tools/server/server.cpp
index 0fc3cf9195..3750c8fdb6 100644
--- a/tools/server/server.cpp
+++ b/tools/server/server.cpp
@@ -1,5 +1,6 @@
 #include "chat.h"
 #include "utils.hpp"
+#include "server-http.h"
 
 #include "arg.h"
 #include "common.h"
@@ -10,13 +11,6 @@
 #include "speculative.h"
 #include "mtmd.h"
 
-// mime type for sending response
-#define MIMETYPE_JSON "application/json; charset=utf-8"
-
-// auto generated files (see README.md for details)
-#include "index.html.gz.hpp"
-#include "loading.html.hpp"
-
 #include
 #include
 #include
@@ -25,11 +19,20 @@
 #include
 #include
 #include
+#include
 #include
 #include
-#include
 #include
+
+// fix problem with std::min and std::max
+#if defined(_WIN32)
+#define WIN32_LEAN_AND_MEAN
+#ifndef NOMINMAX
+#    define NOMINMAX
+#endif
+#include <windows.h>
+#endif
 
 using json = nlohmann::ordered_json;
 
 constexpr int HTTP_POLLING_SECONDS = 1;
@@ -1671,7 +1674,7 @@ struct server_slot {
     server_prompt prompt;
 
     void prompt_save(server_prompt_cache & prompt_cache) const {
-        assert(prompt.data.size() == 0);
+        GGML_ASSERT(prompt.data.size() == 0);
 
         const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0);
 
@@ -2380,6 +2383,7 @@ struct server_context {
         llama_batch_free(batch);
     }
 
+    // load the model and initialize llama_context
     bool load_model(const common_params & params) {
         SRV_INF("loading model '%s'\n", params.model.path.c_str());
 
@@ -2499,6 +2503,7 @@ struct server_context {
         return true;
     }
 
+    // initialize slots and server-related data
     void init() {
         SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel);
 
@@ -2598,6 +2603,11 @@
             /* allow_audio      */ mctx ?
mtmd_support_audio (mctx) : false, /* enable_thinking */ enable_thinking, }; + + // print sample chat example to make it clear which template is used + LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, + common_chat_templates_source(chat_templates.get()), + common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); } server_slot * get_slot_by_id(int id) { @@ -4325,6 +4335,7 @@ struct server_context { } }; + // generator-like API for server responses, support pooling connection state and aggregating results struct server_response_reader { std::unordered_set id_tasks; @@ -4343,7 +4354,7 @@ struct server_response_reader { ctx_server.queue_tasks.post(std::move(tasks)); } - bool has_next() { + bool has_next() const { return !cancelled && received_count < id_tasks.size(); } @@ -4423,281 +4434,46 @@ struct server_response_reader { } }; -static void log_server_request(const httplib::Request & req, const httplib::Response & res) { - // skip GH copilot requests when using default port - if (req.path == "/v1/health") { - return; +// generator-like API for HTTP response generation +struct server_res_generator : server_http_res { + server_response_reader rd; + server_res_generator(server_context & ctx_server_) : rd(ctx_server_) {} + void ok(const json & response_data) { + status = 200; + data = safe_json_to_str(response_data); } - - // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch - - SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status); - - SRV_DBG("request: %s\n", req.body.c_str()); - SRV_DBG("response: %s\n", res.body.c_str()); -} - -static void res_err(httplib::Response & res, const json & error_data) { - json final_response {{"error", error_data}}; - res.set_content(safe_json_to_str(final_response), MIMETYPE_JSON); - res.status = json_value(error_data, "code", 500); -} - -static void res_ok(httplib::Response & res, const json & data) { - res.set_content(safe_json_to_str(data), MIMETYPE_JSON); - res.status = 200; -} - -std::function shutdown_handler; -std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; - -inline void signal_handler(int signal) { - if (is_terminating.test_and_set()) { - // in case it hangs, we can force terminate the server by hitting Ctrl+C twice - // this is for better developer experience, we can remove when the server is stable enough - fprintf(stderr, "Received second interrupt, terminating immediately.\n"); - exit(1); + void error(const json & error_data) { + status = json_value(error_data, "code", 500); + data = safe_json_to_str({{ "error", error_data }}); } +}; - shutdown_handler(signal); -} +struct server_routes { + const common_params & params; + server_context & ctx_server; + server_http_context & ctx_http; // for reading is_ready + server_routes(const common_params & params, server_context & ctx_server, server_http_context & ctx_http) + : params(params), ctx_server(ctx_server), ctx_http(ctx_http) {} -int main(int argc, char ** argv) { - // own arguments required by this example - common_params params; +public: + // handlers using lambda function, so that they can capture `this` without `std::bind` - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) { - return 1; - } - - // TODO: should we have a separate n_parallel parameter for the server? 
- // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177 - // TODO: this is a common configuration that is suitable for most local use cases - // however, overriding the parameters is a bit confusing - figure out something more intuitive - if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) { - LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__); - - params.n_parallel = 4; - params.kv_unified = true; - } - - common_init(); - - // struct that contains llama context and inference - server_context ctx_server; - - llama_backend_init(); - llama_numa_init(params.numa); - - LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); - LOG_INF("\n"); - LOG_INF("%s\n", common_params_get_system_info(params).c_str()); - LOG_INF("\n"); - - std::unique_ptr svr; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (params.ssl_file_key != "" && params.ssl_file_cert != "") { - LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str()); - svr.reset( - new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) - ); - } else { - LOG_INF("Running without SSL\n"); - svr.reset(new httplib::Server()); - } -#else - if (params.ssl_file_key != "" && params.ssl_file_cert != "") { - LOG_ERR("Server is built without SSL support\n"); - return 1; - } - svr.reset(new httplib::Server()); -#endif - - std::atomic state{SERVER_STATE_LOADING_MODEL}; - - svr->set_default_headers({{"Server", "llama.cpp"}}); - svr->set_logger(log_server_request); - svr->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) { - std::string message; - try { - std::rethrow_exception(ep); - } catch (const std::exception & e) { - message = e.what(); - } catch (...) 
{ - message = "Unknown Exception"; - } - - try { - json formatted_error = format_error_response(message, ERROR_TYPE_SERVER); - LOG_WRN("got exception: %s\n", formatted_error.dump().c_str()); - res_err(res, formatted_error); - } catch (const std::exception & e) { - LOG_ERR("got another exception: %s | while hanlding exception: %s\n", e.what(), message.c_str()); - } - }); - - svr->set_error_handler([](const httplib::Request &, httplib::Response & res) { - if (res.status == 404) { - res_err(res, format_error_response("File Not Found", ERROR_TYPE_NOT_FOUND)); - } - // for other error codes, we skip processing here because it's already done by res_err() - }); - - // set timeouts and change hostname and port - svr->set_read_timeout (params.timeout_read); - svr->set_write_timeout(params.timeout_write); - - std::unordered_map log_data; - - log_data["hostname"] = params.hostname; - log_data["port"] = std::to_string(params.port); - - if (params.api_keys.size() == 1) { - auto key = params.api_keys[0]; - log_data["api_key"] = "api_key: ****" + key.substr(std::max((int)(key.length() - 4), 0)); - } else if (params.api_keys.size() > 1) { - log_data["api_key"] = "api_key: " + std::to_string(params.api_keys.size()) + " keys loaded"; - } - - // Necessary similarity of prompt for slot selection - ctx_server.slot_prompt_similarity = params.slot_prompt_similarity; - - // - // Middlewares - // - - auto middleware_validate_api_key = [¶ms](const httplib::Request & req, httplib::Response & res) { - static const std::unordered_set public_endpoints = { - "/health", - "/v1/health", - "/models", - "/v1/models", - "/api/tags" - }; - - // If API key is not set, skip validation - if (params.api_keys.empty()) { - return true; - } - - // If path is public or is static file, skip validation - if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { - return true; - } - - // Check for API key in the header - auto auth_header = req.get_header_value("Authorization"); - - std::string prefix = "Bearer "; - if (auth_header.substr(0, prefix.size()) == prefix) { - std::string received_api_key = auth_header.substr(prefix.size()); - if (std::find(params.api_keys.begin(), params.api_keys.end(), received_api_key) != params.api_keys.end()) { - return true; // API key is valid - } - } - - // API key is invalid or not provided - res_err(res, format_error_response("Invalid API Key", ERROR_TYPE_AUTHENTICATION)); - - LOG_WRN("Unauthorized: Invalid API Key\n"); - - return false; - }; - - auto middleware_server_state = [&state](const httplib::Request & req, httplib::Response & res) { - server_state current_state = state.load(); - if (current_state == SERVER_STATE_LOADING_MODEL) { - auto tmp = string_split(req.path, '.'); - if (req.path == "/" || tmp.back() == "html") { - res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); - res.status = 503; - } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { - // allow the models endpoint to be accessed during loading - return true; - } else { - res_err(res, format_error_response("Loading model", ERROR_TYPE_UNAVAILABLE)); - } - return false; - } - return true; - }; - - // register server middlewares - svr->set_pre_routing_handler([&middleware_validate_api_key, &middleware_server_state](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - // If this is OPTIONS request, skip validation because browsers don't include 
Authorization header - if (req.method == "OPTIONS") { - res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "GET, POST"); - res.set_header("Access-Control-Allow-Headers", "*"); - res.set_content("", "text/html"); // blank response, no data - return httplib::Server::HandlerResponse::Handled; // skip further processing - } - if (!middleware_server_state(req, res)) { - return httplib::Server::HandlerResponse::Handled; - } - if (!middleware_validate_api_key(req, res)) { - return httplib::Server::HandlerResponse::Handled; - } - return httplib::Server::HandlerResponse::Unhandled; - }); - - // - // Route handlers (or controllers) - // - - const auto handle_health = [&](const httplib::Request &, httplib::Response & res) { + server_http_context::handler_t get_health = [this](const server_http_req &) { // error and loading states are handled by middleware - json health = {{"status", "ok"}}; - res_ok(res, health); + auto res = std::make_unique(ctx_server); + res->ok({{"status", "ok"}}); + return res; }; - const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { - if (!params.endpoint_slots) { - res_err(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); - return; - } - - // request slots data using task queue - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res_err(res, result->to_json()); - return; - } - - // TODO: get rid of this dynamic_cast - auto res_task = dynamic_cast(result.get()); - GGML_ASSERT(res_task != nullptr); - - // optionally return "fail_on_no_slot" error - if (req.has_param("fail_on_no_slot")) { - if (res_task->n_idle_slots == 0) { - res_err(res, format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); - return; - } - } - - res_ok(res, res_task->slots_data); - }; - - const auto handle_metrics = [&](const httplib::Request &, httplib::Response & res) { + server_http_context::handler_t get_metrics = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); if (!params.endpoint_metrics) { - res_err(res, format_error_response("This server does not support metrics endpoint. Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); - return; + res->error(format_error_response("This server does not support metrics endpoint. 
Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } // request slots data using task queue + // TODO: use server_response_reader int task_id = ctx_server.queue_tasks.get_new_id(); { server_task task(SERVER_TASK_TYPE_METRICS); @@ -4711,8 +4487,8 @@ int main(int argc, char ** argv) { ctx_server.queue_results.remove_waiting_task_id(task_id); if (result->is_error()) { - res_err(res, result->to_json()); - return; + res->error(result->to_json()); + return res; } // TODO: get rid of this dynamic_cast @@ -4786,130 +4562,86 @@ int main(int argc, char ** argv) { } } - res.set_header("Process-Start-Time-Unix", std::to_string(res_task->t_start)); - - res.set_content(prometheus.str(), "text/plain; version=0.0.4"); - res.status = 200; // HTTP OK + res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start); + res->content_type = "text/plain; version=0.0.4"; + res->ok(prometheus.str()); + return res; }; - const auto handle_slots_save = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { - json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return; + server_http_context::handler_t get_slots = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_slots) { + res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } - std::string filepath = params.slot_save_path + filename; + // request slots data using task queue int task_id = ctx_server.queue_tasks.get_new_id(); { - server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + server_task task(SERVER_TASK_TYPE_METRICS); task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); + ctx_server.queue_tasks.post(std::move(task), true); // high-priority task } + // get the result server_task_result_ptr result = ctx_server.queue_results.recv(task_id); ctx_server.queue_results.remove_waiting_task_id(task_id); if (result->is_error()) { - res_err(res, result->to_json()); - return; + res->error(result->to_json()); + return res; } - res_ok(res, result->to_json()); + // TODO: get rid of this dynamic_cast + auto res_task = dynamic_cast(result.get()); + GGML_ASSERT(res_task != nullptr); + + // optionally return "fail_on_no_slot" error + if (!req.get_param("fail_on_no_slot").empty()) { + if (res_task->n_idle_slots == 0) { + res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); + return res; + } + } + + res->ok(res_task->slots_data); + return res; }; - const auto handle_slots_restore = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res, int id_slot) { - json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res_err(res, format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return; - } - std::string filepath = params.slot_save_path + filename; - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath 
= filepath; - - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res_err(res, result->to_json()); - return; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res_ok(res, result->to_json()); - }; - - const auto handle_slots_erase = [&ctx_server](const httplib::Request & /* req */, httplib::Response & res, int id_slot) { - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_ERASE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res_err(res, result->to_json()); - return; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res_ok(res, result->to_json()); - }; - - const auto handle_slots_action = [¶ms, &handle_slots_save, &handle_slots_restore, &handle_slots_erase](const httplib::Request & req, httplib::Response & res) { + server_http_context::handler_t post_slots = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); if (params.slot_save_path.empty()) { - res_err(res, format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); - return; + res->error(format_error_response("This server does not support slots action. Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } - std::string id_slot_str = req.path_params.at("id_slot"); + std::string id_slot_str = req.get_param("id_slot"); int id_slot; try { id_slot = std::stoi(id_slot_str); } catch (const std::exception &) { - res_err(res, format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); + return res; } - std::string action = req.get_param_value("action"); + std::string action = req.get_param("action"); if (action == "save") { - handle_slots_save(req, res, id_slot); + return handle_slots_save(req, id_slot); } else if (action == "restore") { - handle_slots_restore(req, res, id_slot); + return handle_slots_restore(req, id_slot); } else if (action == "erase") { - handle_slots_erase(req, res, id_slot); + return handle_slots_erase(req, id_slot); } else { - res_err(res, format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); + return res; } }; - const auto handle_props = [¶ms, &ctx_server](const httplib::Request &, httplib::Response & res) { + server_http_context::handler_t get_props = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); json default_generation_settings_for_props; { @@ -4948,23 +4680,24 @@ int main(int argc, char ** argv) { } } - res_ok(res, data); + res->ok(data); + return res; }; - const auto handle_props_change = [&ctx_server](const httplib::Request & req, httplib::Response & res) { - if (!ctx_server.params_base.endpoint_props) { - res_err(res, format_error_response("This server does not support changing global properties. 
Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); - return; + server_http_context::handler_t post_props = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + if (!params.endpoint_props) { + res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); + return res; } - - json data = json::parse(req.body); - // update any props here - res_ok(res, {{ "success", true }}); + res->ok({{ "success", true }}); + return res; }; - const auto handle_api_show = [&ctx_server](const httplib::Request &, httplib::Response & res) { + server_http_context::handler_t get_api_show = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); bool has_mtmd = ctx_server.mctx != nullptr; json data = { { @@ -4990,24 +4723,404 @@ int main(int argc, char ** argv) { {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})} }; - res_ok(res, data); + res->ok(data); + return res; }; - // handle completion-like requests (completion, chat, infill) - // we can optionally provide a custom format for partial results and final results - const auto handle_completions_impl = [&ctx_server]( - server_task_type type, - json & data, - const std::vector & files, - const std::function & is_connection_closed, - httplib::Response & res, - oaicompat_type oaicompat) -> void { + server_http_context::handler_t post_infill = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + // check model compatibility + std::string err; + if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "prefix token is missing. "; + } + if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "suffix token is missing. "; + } + if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { + err += "middle token is missing. 
"; + } + if (!err.empty()) { + res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + // validate input + json data = json::parse(req.body); + if (data.contains("prompt") && !data.at("prompt").is_string()) { + // prompt is optional + res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_prefix")) { + res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (!data.contains("input_suffix")) { + res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); + } + + if (data.contains("input_extra") && !data.at("input_extra").is_array()) { + // input_extra is optional + res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + json input_extra = json_value(data, "input_extra", json::array()); + for (const auto & chunk : input_extra) { + // { "text": string, "filename": string } + if (!chunk.contains("text") || !chunk.at("text").is_string()) { + res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + // filename is optional + if (chunk.contains("filename") && !chunk.at("filename").is_string()) { + res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } + data["input_extra"] = input_extra; // default to empty array if it's not exist + + std::string prompt = json_value(data, "prompt", std::string()); + std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true); + SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); + data["prompt"] = format_infill( + ctx_server.vocab, + data.at("input_prefix"), + data.at("input_suffix"), + data.at("input_extra"), + ctx_server.params_base.n_batch, + ctx_server.params_base.n_predict, + ctx_server.slots[0].n_ctx, // TODO: there should be a better way + ctx_server.params_base.spm_infill, + tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. 
+ ); + + std::vector files; // dummy + return handle_completions_impl( + SERVER_TASK_TYPE_INFILL, + data, + files, + req.should_stop, + OAICOMPAT_TYPE_NONE); // infill is not OAI compatible + }; + + server_http_context::handler_t post_completions = [this](const server_http_req & req) { + std::vector files; // dummy + const json body = json::parse(req.body); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + body, + files, + req.should_stop, + OAICOMPAT_TYPE_NONE); + }; + + server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) { + std::vector files; // dummy + const json body = json::parse(req.body); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + body, + files, + req.should_stop, + OAICOMPAT_TYPE_COMPLETION); + }; + + server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) { + std::vector files; + json body = json::parse(req.body); + json body_parsed = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + return handle_completions_impl( + SERVER_TASK_TYPE_COMPLETION, + body_parsed, + files, + req.should_stop, + OAICOMPAT_TYPE_CHAT); + }; + + // same with handle_chat_completions, but without inference part + server_http_context::handler_t post_apply_template = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + std::vector files; // dummy, unused + json body = json::parse(req.body); + json data = oaicompat_chat_params_parse( + body, + ctx_server.oai_parser_opt, + files); + res->ok({{ "prompt", std::move(data.at("prompt")) }}); + return res; + }; + + server_http_context::handler_t get_models = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + bool is_model_ready = ctx_http.is_ready.load(); + json model_meta = nullptr; + if (is_model_ready) { + model_meta = ctx_server.model_meta(); + } + bool has_mtmd = ctx_server.mctx != nullptr; + json models = { + {"models", { + { + {"name", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"model", params.model_alias.empty() ? params.model.path : params.model_alias}, + {"modified_at", ""}, + {"size", ""}, + {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash + {"type", "model"}, + {"description", ""}, + {"tags", {""}}, + {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})}, + {"parameters", ""}, + {"details", { + {"parent_model", ""}, + {"format", "gguf"}, + {"family", ""}, + {"families", {""}}, + {"parameter_size", ""}, + {"quantization_level", ""} + }} + } + }}, + {"object", "list"}, + {"data", { + { + {"id", params.model_alias.empty() ? 
params.model.path : params.model_alias}, + {"object", "model"}, + {"created", std::time(0)}, + {"owned_by", "llamacpp"}, + {"meta", model_meta}, + }, + }} + }; + + res->ok(models); + return res; + }; + + server_http_context::handler_t post_tokenize = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + json tokens_response = json::array(); + if (body.count("content") != 0) { + const bool add_special = json_value(body, "add_special", false); + const bool parse_special = json_value(body, "parse_special", true); + const bool with_pieces = json_value(body, "with_pieces", false); + + llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special); + + if (with_pieces) { + for (const auto& token : tokens) { + std::string piece = common_token_to_piece(ctx_server.ctx, token); + json piece_json; + + // Check if the piece is valid UTF-8 + if (is_valid_utf8(piece)) { + piece_json = piece; + } else { + // If not valid UTF-8, store as array of byte values + piece_json = json::array(); + for (unsigned char c : piece) { + piece_json.push_back(static_cast(c)); + } + } + + tokens_response.push_back({ + {"id", token}, + {"piece", piece_json} + }); + } + } else { + tokens_response = tokens; + } + } + + const json data = format_tokenizer_response(tokens_response); + res->ok(data); + return res; + }; + + server_http_context::handler_t post_detokenize = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + + std::string content; + if (body.count("tokens") != 0) { + const llama_tokens tokens = body.at("tokens"); + content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend()); + } + + const json data = format_detokenized_response(content); + res->ok(data); + return res; + }; + + server_http_context::handler_t post_embeddings = [this](const server_http_req & req) { + return handle_embeddings_impl(req, OAICOMPAT_TYPE_NONE); + }; + + server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) { + return handle_embeddings_impl(req, OAICOMPAT_TYPE_EMBEDDING); + }; + + server_http_context::handler_t post_rerank = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { + res->error(format_error_response("This server does not support reranking. 
Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); + return res; + } + + const json body = json::parse(req.body); + + // if true, use TEI API format, otherwise use Jina API format + // Jina: https://jina.ai/reranker/ + // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank + bool is_tei_format = body.contains("texts"); + + json query; + if (body.count("query") == 1) { + query = body.at("query"); + if (!query.is_string()) { + res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + } else { + res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + std::vector documents = json_value(body, "documents", + json_value(body, "texts", std::vector())); + if (documents.empty()) { + res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + int top_n = json_value(body, "top_n", (int)documents.size()); + + // create and queue the task + json responses = json::array(); + server_response_reader rd(ctx_server); + { + std::vector tasks; + tasks.reserve(documents.size()); + for (size_t i = 0; i < documents.size(); i++) { + auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); + server_task task = server_task(SERVER_TASK_TYPE_RERANK); + task.id = ctx_server.queue_tasks.get_new_id(); + task.index = i; + task.tokens = std::move(tmp); + tasks.push_back(std::move(task)); + } + rd.post_tasks(std::move(tasks)); + } + + // wait for the results + auto all_results = rd.wait_for_all(req.should_stop); + + // collect results + if (all_results.is_terminated) { + return res; // connection is closed + } else if (all_results.error) { + res->error(all_results.error->to_json()); + return res; + } else { + for (auto & res : all_results.results) { + GGML_ASSERT(dynamic_cast(res.get()) != nullptr); + responses.push_back(res->to_json()); + } + } + + // write JSON response + json root = format_response_rerank( + body, + responses, + is_tei_format, + documents, + top_n); + + res->ok(root); + return res; + }; + + server_http_context::handler_t get_lora_adapters = [this](const server_http_req &) { + auto res = std::make_unique(ctx_server); + json result = json::array(); + const auto & loras = ctx_server.params_base.lora_adapters; + for (size_t i = 0; i < loras.size(); ++i) { + auto & lora = loras[i]; + json entry = { + {"id", i}, + {"path", lora.path}, + {"scale", lora.scale}, + {"task_name", lora.task_name}, + {"prompt_prefix", lora.prompt_prefix}, + }; + std::string alora_invocation_string = ""; + const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); + std::vector alora_invocation_tokens; + if (n_alora_tokens) { + const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); + for (uint64_t i = 0; i < n_alora_tokens; ++i) { + alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]); + alora_invocation_tokens.push_back(alora_tokens[i]); + } + entry["alora_invocation_string"] = alora_invocation_string; + entry["alora_invocation_tokens"] = alora_invocation_tokens; + } + result.push_back(std::move(entry)); + } + res->ok(result); + return res; + }; + + server_http_context::handler_t post_lora_adapters = [this](const server_http_req & req) { + auto res = std::make_unique(ctx_server); + const json body = json::parse(req.body); + if (!body.is_array()) { + 
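+            // anything that is not an array is rejected here; a valid body is
+            // expected to look like (illustrative values):
+            //   [ {"id": 0, "scale": 0.5}, {"id": 1, "scale": 0.0} ]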
res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SET_LORA); + task.id = task_id; + task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); + } + + // get the result + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res->ok(result->to_json()); + return res; + }; + +private: + std::unique_ptr handle_completions_impl( + server_task_type type, + const json & data, + const std::vector & files, + const std::function & should_stop, + oaicompat_type oaicompat) { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); + auto res = std::make_unique(ctx_server); auto completion_id = gen_chatcmplid(); - // need to store the reader as a pointer, so that it won't be destroyed when the handle returns - // use shared_ptr as it's shared between the chunked_content_provider() and on_complete() - const auto rd = std::make_shared(ctx_server); + auto & rd = res->rd; try { std::vector tasks; @@ -5048,22 +5161,22 @@ int main(int argc, char ** argv) { tasks.push_back(std::move(task)); } - rd->post_tasks(std::move(tasks)); + rd.post_tasks(std::move(tasks)); } catch (const std::exception & e) { - res_err(res, format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return; + res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); + return res; } bool stream = json_value(data, "stream", false); if (!stream) { // non-stream, wait for the results - auto all_results = rd->wait_for_all(is_connection_closed); + auto all_results = rd.wait_for_all(should_stop); if (all_results.is_terminated) { - return; // connection is closed + return res; // connection is closed } else if (all_results.error) { - res_err(res, all_results.error->to_json()); - return; + res->error(all_results.error->to_json()); + return res; } else { json arr = json::array(); for (auto & res : all_results.results) { @@ -5071,19 +5184,19 @@ int main(int argc, char ** argv) { arr.push_back(res->to_json()); } // if single request, return single object instead of array - res_ok(res, arr.size() == 1 ? arr[0] : arr); + res->ok(arr.size() == 1 ? 
arr[0] : arr); } } else { // in streaming mode, the first error must be treated as non-stream response // this is to match the OAI API behavior // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd->next(is_connection_closed); + server_task_result_ptr first_result = rd.next(should_stop); if (first_result == nullptr) { - return; // connection is closed + return res; // connection is closed } else if (first_result->is_error()) { - res_err(res, first_result->to_json()); - return; + res->error(first_result->to_json()); + return res; } else { GGML_ASSERT( dynamic_cast(first_result.get()) != nullptr @@ -5092,307 +5205,171 @@ int main(int argc, char ** argv) { } // next responses are streamed - json first_result_json = first_result->to_json(); - const auto chunked_content_provider = [first_result_json, rd, oaicompat](size_t, httplib::DataSink & sink) mutable -> bool { - // flush the first result as it's not an error - if (!first_result_json.empty()) { - if (!server_sent_event(sink, first_result_json)) { - sink.done(); - return false; // sending failed, go to on_complete() + res->data = format_sse(first_result->to_json()); // to be sent immediately + res->status = 200; + res->content_type = "text/event-stream"; + res->next = [res_this = res.get(), oaicompat, &should_stop](std::string & output) -> bool { + if (should_stop()) { + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met + } + + if (!res_this->data.empty()) { + // flush the first chunk + output = std::move(res_this->data); + res_this->data.clear(); + return true; + } + + server_response_reader & rd = res_this->rd; + + // check if there is more data + if (!rd.has_next()) { + if (oaicompat != OAICOMPAT_TYPE_NONE) { + output = "data: [DONE]\n\n"; + } else { + output = ""; } - first_result_json.clear(); // mark as sent + SRV_DBG("%s", "all results received, terminating stream\n"); + return false; // no more data, terminate } // receive subsequent results - auto result = rd->next([&sink]{ return !sink.is_writable(); }); + auto result = rd.next(should_stop); if (result == nullptr) { - sink.done(); - return false; // connection is closed, go to on_complete() + SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); + return false; // should_stop condition met } // send the results json res_json = result->to_json(); - bool ok = false; if (result->is_error()) { - ok = server_sent_event(sink, json {{ "error", result->to_json() }}); - sink.done(); - return false; // go to on_complete() + output = format_sse(json {{ "error", res_json }}); + SRV_DBG("%s", "error received during streaming, terminating stream\n"); + return false; // terminate on error } else { GGML_ASSERT( dynamic_cast(result.get()) != nullptr || dynamic_cast(result.get()) != nullptr ); - ok = server_sent_event(sink, res_json); - } - - if (!ok) { - sink.done(); - return false; // sending failed, go to on_complete() - } - - // check if there is more data - if (!rd->has_next()) { - if (oaicompat != OAICOMPAT_TYPE_NONE) { - static const std::string ev_done = "data: [DONE]\n\n"; - sink.write(ev_done.data(), ev_done.size()); - } - sink.done(); - return false; // no more data, go to on_complete() + output = format_sse(res_json); } // has next data, continue return true; }; - - auto on_complete = [rd](bool) { - rd->stop(); - }; - - res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete); - } - }; - - const auto 
handle_completions = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - json data = json::parse(req.body); - std::vector files; // dummy - handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_NONE); - }; - - const auto handle_completions_oai = [&handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - json data = oaicompat_completion_params_parse(json::parse(req.body)); - std::vector files; // dummy - handle_completions_impl( - SERVER_TASK_TYPE_COMPLETION, - data, - files, - req.is_connection_closed, - res, - OAICOMPAT_TYPE_COMPLETION); - }; - - const auto handle_infill = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) { - // check model compatibility - std::string err; - if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "prefix token is missing. "; - } - if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "suffix token is missing. "; - } - if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "middle token is missing. "; - } - if (!err.empty()) { - res_err(res, format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); - return; } - json data = json::parse(req.body); + return res; + } - // validate input - if (data.contains("prompt") && !data.at("prompt").is_string()) { - // prompt is optional - res_err(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); + std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_SAVE); + task.id = task_id; + task.slot_action.slot_id = id_slot; + task.slot_action.filename = filename; + task.slot_action.filepath = filepath; + + // TODO: use server_response_reader + ctx_server.queue_results.add_waiting_task_id(task_id); + ctx_server.queue_tasks.post(std::move(task)); } - if (!data.contains("input_prefix")) { - res_err(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); + server_task_result_ptr result = ctx_server.queue_results.recv(task_id); + ctx_server.queue_results.remove_waiting_task_id(task_id); + + if (result->is_error()) { + res->error(result->to_json()); + return res; } - if (!data.contains("input_suffix")) { - res_err(res, format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); + res->ok(result->to_json()); + return res; + } + + std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot) { + auto res = std::make_unique(ctx_server); + const json request_data = json::parse(req.body); + std::string filename = request_data.at("filename"); + if (!fs_validate_filename(filename)) { + res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); + return res; + } + std::string filepath = params.slot_save_path + filename; + + int task_id = ctx_server.queue_tasks.get_new_id(); + { + server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); + task.id = task_id; + 
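+            // same fire-and-wait pattern as handle_slots_save: fill in the slot
+            // action, post the task, then block on queue_results.recv(task_id) below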
+
+    std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot) {
+        auto res = std::make_unique(ctx_server);
+        const json request_data = json::parse(req.body);
+        std::string filename = request_data.at("filename");
+        if (!fs_validate_filename(filename)) {
+            res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST));
+            return res;
+        }
+        std::string filepath = params.slot_save_path + filename;
+
+        int task_id = ctx_server.queue_tasks.get_new_id();
+        {
+            server_task task(SERVER_TASK_TYPE_SLOT_RESTORE);
+            task.id = task_id;
+            task.slot_action.slot_id = id_slot;
+            task.slot_action.filename = filename;
+            task.slot_action.filepath = filepath;
+
+            // TODO: use server_response_reader
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task));
         }
-        if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
-            // input_extra is optional
-            res_err(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
-            return;
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
+
+        if (result->is_error()) {
+            res->error(result->to_json());
+            return res;
         }
-        json input_extra = json_value(data, "input_extra", json::array());
-        for (const auto & chunk : input_extra) {
-            // { "text": string, "filename": string }
-            if (!chunk.contains("text") || !chunk.at("text").is_string()) {
-                res_err(res, format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST));
-                return;
-            }
-            // filename is optional
-            if (chunk.contains("filename") && !chunk.at("filename").is_string()) {
-                res_err(res, format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST));
-                return;
-            }
-        }
-        data["input_extra"] = input_extra; // default to empty array if it's not exist
+        GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
+        res->ok(result->to_json());
+        return res;
+    }
-        std::string prompt = json_value(data, "prompt", std::string());
-        std::vector<server_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true);
-        SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
-        data["prompt"] = format_infill(
-            ctx_server.vocab,
-            data.at("input_prefix"),
-            data.at("input_suffix"),
-            data.at("input_extra"),
-            ctx_server.params_base.n_batch,
-            ctx_server.params_base.n_predict,
-            ctx_server.slots[0].n_ctx, // TODO: there should be a better way
-            ctx_server.params_base.spm_infill,
-            tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal.
-        );
+    std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot) {
+        auto res = std::make_unique(ctx_server);
+        int task_id = ctx_server.queue_tasks.get_new_id();
+        {
+            server_task task(SERVER_TASK_TYPE_SLOT_ERASE);
+            task.id = task_id;
+            task.slot_action.slot_id = id_slot;
-        std::vector<raw_buffer> files; // dummy
-        handle_completions_impl(
-            SERVER_TASK_TYPE_INFILL,
-            data,
-            files,
-            req.is_connection_closed,
-            res,
-            OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
-    };
-
-    const auto handle_chat_completions = [&ctx_server, &handle_completions_impl](const httplib::Request & req, httplib::Response & res) {
-        LOG_DBG("request: %s\n", req.body.c_str());
-
-        auto body = json::parse(req.body);
-        std::vector<raw_buffer> files;
-        json data = oaicompat_chat_params_parse(
-            body,
-            ctx_server.oai_parser_opt,
-            files);
-
-        handle_completions_impl(
-            SERVER_TASK_TYPE_COMPLETION,
-            data,
-            files,
-            req.is_connection_closed,
-            res,
-            OAICOMPAT_TYPE_CHAT);
-    };
-
-    // same with handle_chat_completions, but without inference part
-    const auto handle_apply_template = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
-        auto body = json::parse(req.body);
-        std::vector<raw_buffer> files; // dummy, unused
-        json data = oaicompat_chat_params_parse(
-            body,
-            ctx_server.oai_parser_opt,
-            files);
-        res_ok(res, {{ "prompt", std::move(data.at("prompt")) }});
-    };
-
-    const auto handle_models = [&params, &ctx_server, &state](const httplib::Request &, httplib::Response & res) {
-        server_state current_state = state.load();
-        json model_meta = nullptr;
-        if (current_state == SERVER_STATE_READY) {
-            model_meta = ctx_server.model_meta();
-        }
-        bool has_mtmd = ctx_server.mctx != nullptr;
-        json models = {
-            {"models", {
-                {
-                    {"name", params.model_alias.empty() ? params.model.path : params.model_alias},
-                    {"model", params.model_alias.empty() ? params.model.path : params.model_alias},
-                    {"modified_at", ""},
-                    {"size", ""},
-                    {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash
-                    {"type", "model"},
-                    {"description", ""},
-                    {"tags", {""}},
-                    {"capabilities", has_mtmd ? json({"completion","multimodal"}) : json({"completion"})},
-                    {"parameters", ""},
-                    {"details", {
-                        {"parent_model", ""},
-                        {"format", "gguf"},
-                        {"family", ""},
-                        {"families", {""}},
-                        {"parameter_size", ""},
-                        {"quantization_level", ""}
-                    }}
-                }
-            }},
-            {"object", "list"},
-            {"data", {
-                {
-                    {"id", params.model_alias.empty() ? params.model.path : params.model_alias},
-                    {"object", "model"},
-                    {"created", std::time(0)},
-                    {"owned_by", "llamacpp"},
-                    {"meta", model_meta},
-                },
-            }}
-        };
-
-        res_ok(res, models);
-    };
-
-    const auto handle_tokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
-        const json body = json::parse(req.body);
-
-        json tokens_response = json::array();
-        if (body.count("content") != 0) {
-            const bool add_special = json_value(body, "add_special", false);
-            const bool parse_special = json_value(body, "parse_special", true);
-            const bool with_pieces = json_value(body, "with_pieces", false);
-
-            llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special);
-
-            if (with_pieces) {
-                for (const auto& token : tokens) {
-                    std::string piece = common_token_to_piece(ctx_server.ctx, token);
-                    json piece_json;
-
-                    // Check if the piece is valid UTF-8
-                    if (is_valid_utf8(piece)) {
-                        piece_json = piece;
-                    } else {
-                        // If not valid UTF-8, store as array of byte values
-                        piece_json = json::array();
-                        for (unsigned char c : piece) {
-                            piece_json.push_back(static_cast<int>(c));
-                        }
-                    }
-
-                    tokens_response.push_back({
-                        {"id", token},
-                        {"piece", piece_json}
-                    });
-                }
-            } else {
-                tokens_response = tokens;
-            }
+
+            // TODO: use server_response_reader
+            ctx_server.queue_results.add_waiting_task_id(task_id);
+            ctx_server.queue_tasks.post(std::move(task));
         }
-        const json data = format_tokenizer_response(tokens_response);
-        res_ok(res, data);
-    };
+        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
+        ctx_server.queue_results.remove_waiting_task_id(task_id);
-    const auto handle_detokenize = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
-        const json body = json::parse(req.body);
-
-        std::string content;
-        if (body.count("tokens") != 0) {
-            const llama_tokens tokens = body.at("tokens");
-            content = tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend());
+        if (result->is_error()) {
+            res->error(result->to_json());
+            return res;
         }
-        const json data = format_detokenized_response(content);
-        res_ok(res, data);
-    };
+        GGML_ASSERT(dynamic_cast<server_task_result_slot_save_load*>(result.get()) != nullptr);
+        res->ok(result->to_json());
+        return res;
+    }
-    const auto handle_embeddings_impl = [&ctx_server](const httplib::Request & req, httplib::Response & res, oaicompat_type oaicompat) {
+    std::unique_ptr handle_embeddings_impl(const server_http_req & req, oaicompat_type oaicompat) {
+        auto res = std::make_unique(ctx_server);
         if (!ctx_server.params_base.embedding) {
-            res_err(res, format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
-            return;
+            res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
+            return res;
         }

         if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
-            res_err(res, format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
-            return;
+            res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
+            return res;
         }

         const json body = json::parse(req.body);
@@ -5405,8 +5382,8 @@ int main(int argc, char ** argv) {
             oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
             prompt = body.at("content");
         } else {
-            res_err(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
-            return;
+            res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
+            return res;
         }

         bool use_base64 = false;
@@ -5415,8 +5392,8 @@ int main(int argc, char ** argv) {
             if (format == "base64") {
                 use_base64 = true;
             } else if (format != "float") {
-                res_err(res, format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
-                return;
+                res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST));
+                return res;
             }
         }
@@ -5424,8 +5401,8 @@ int main(int argc, char ** argv) {
         for (const auto & tokens : tokenized_prompts) {
             // this check is necessary for models that do not add BOS token to the input
             if (tokens.empty()) {
-                res_err(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
-                return;
+                res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
+                return res;
             }
         }
@@ -5459,14 +5436,14 @@ int main(int argc, char ** argv) {
         }

         // wait for the results
-        auto all_results = rd.wait_for_all(req.is_connection_closed);
+        auto all_results = rd.wait_for_all(req.should_stop);

         // collect results
         if (all_results.is_terminated) {
-            return; // connection is closed
+            return res; // connection is closed
         } else if (all_results.error) {
-            res_err(res, all_results.error->to_json());
-            return;
+            res->error(all_results.error->to_json());
+            return res;
         } else {
             for (auto & res : all_results.results) {
                 GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
@@ -5478,292 +5455,170 @@ int main(int argc, char ** argv) {
         json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
             ? format_embeddings_response_oaicompat(body, responses, use_base64)
             : json(responses);
-        res_ok(res, root);
+        res->ok(root);
+        return res;
+    }
+};
+
+std::function<void(int)> shutdown_handler;
+std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
+
+inline void signal_handler(int signal) {
+    if (is_terminating.test_and_set()) {
+        // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
+        // this is for better developer experience, we can remove when the server is stable enough
+        fprintf(stderr, "Received second interrupt, terminating immediately.\n");
+        exit(1);
+    }
+
+    shutdown_handler(signal);
+}
+
+// wrapper function that handles exceptions and logs errors
+// this is to make sure handler_t never throws exceptions; instead, it returns an error response
+static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) {
+    return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr {
+        std::string message;
+        try {
+            return func(req);
+        } catch (const std::exception & e) {
+            message = e.what();
+        } catch (...) {
+            message = "unknown error";
+        }
+
+        auto res = std::make_unique();
+        res->status = 500;
+        try {
+            json error_data = format_error_response(message, ERROR_TYPE_SERVER);
+            res->status = json_value(error_data, "code", 500);
+            res->data = safe_json_to_str({{ "error", error_data }});
+            LOG_WRN("got exception: %s\n", res->data.c_str());
+        } catch (const std::exception & e) {
+            LOG_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str());
+            res->data = "Internal Server Error";
+        }
+        return res;
     };
+}
-    const auto handle_embeddings = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
-        handle_embeddings_impl(req, res, OAICOMPAT_TYPE_NONE);
-    };
+int main(int argc, char ** argv) {
+    // own arguments required by this example
+    common_params params;
-    const auto handle_embeddings_oai = [&handle_embeddings_impl](const httplib::Request & req, httplib::Response & res) {
-        handle_embeddings_impl(req, res, OAICOMPAT_TYPE_EMBEDDING);
-    };
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) {
+        return 1;
+    }
-    const auto handle_rerank = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
-        if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
-            res_err(res, format_error_response("This server does not support reranking. Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED));
-            return;
-        }
+    // TODO: should we have a separate n_parallel parameter for the server?
+    //       https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177
+    // TODO: this is a common configuration that is suitable for most local use cases
+    //       however, overriding the parameters is a bit confusing - figure out something more intuitive
+    if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) {
+        LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__);
-        const json body = json::parse(req.body);
+        params.n_parallel = 4;
+        params.kv_unified = true;
+    }
-        // if true, use TEI API format, otherwise use Jina API format
-        // Jina: https://jina.ai/reranker/
-        // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank
-        bool is_tei_format = body.contains("texts");
+    common_init();
-        json query;
-        if (body.count("query") == 1) {
-            query = body.at("query");
-            if (!query.is_string()) {
-                res_err(res, format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST));
-                return;
-            }
-        } else {
-            res_err(res, format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST));
-            return;
-        }
+    // struct that contains llama context and inference
+    server_context ctx_server;
-        std::vector<std::string> documents = json_value(body, "documents",
-            json_value(body, "texts", std::vector<std::string>()));
-        if (documents.empty()) {
-            res_err(res, format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST));
-            return;
-        }
+    // Necessary similarity of prompt for slot selection
+    ctx_server.slot_prompt_similarity = params.slot_prompt_similarity;
-        int top_n = json_value(body, "top_n", (int)documents.size());
+    llama_backend_init();
+    llama_numa_init(params.numa);
-        // create and queue the task
-        json responses = json::array();
-        server_response_reader rd(ctx_server);
-        {
-            std::vector<server_task> tasks;
-            tasks.reserve(documents.size());
-            for (size_t i = 0; i < documents.size(); i++) {
-                auto tmp = format_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]);
-                server_task task = server_task(SERVER_TASK_TYPE_RERANK);
-                task.id = ctx_server.queue_tasks.get_new_id();
-                task.index = i;
-                task.tokens = std::move(tmp);
-                tasks.push_back(std::move(task));
-            }
-            rd.post_tasks(std::move(tasks));
-        }
+    LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
+    LOG_INF("\n");
+    LOG_INF("%s\n", common_params_get_system_info(params).c_str());
+    LOG_INF("\n");
-        // wait for the results
-        auto all_results = rd.wait_for_all(req.is_connection_closed);
-
-        // collect results
-        if (all_results.is_terminated) {
-            return; // connection is closed
-        } else if (all_results.error) {
-            res_err(res, all_results.error->to_json());
-            return;
-        } else {
-            for (auto & res : all_results.results) {
-                GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
-                responses.push_back(res->to_json());
-            }
-        }
-
-        // write JSON response
-        json root = format_response_rerank(
-            body,
-            responses,
-            is_tei_format,
-            documents,
-            top_n);
-
-        res_ok(res, root);
-    };
-
-    const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
-        json result = json::array();
-        const auto & loras = ctx_server.params_base.lora_adapters;
-        for (size_t i = 0; i < loras.size(); ++i) {
-            auto & lora = loras[i];
-            json entry = {
-                {"id", i},
-                {"path", lora.path},
-                {"scale", lora.scale},
-                {"task_name", lora.task_name},
-                {"prompt_prefix", lora.prompt_prefix},
-            };
-            std::string alora_invocation_string = "";
-            const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr);
-            std::vector<llama_token> alora_invocation_tokens;
-            if (n_alora_tokens) {
-                const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr);
-                for (uint64_t i = 0; i < n_alora_tokens; ++i) {
-                    alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]);
-                    alora_invocation_tokens.push_back(alora_tokens[i]);
-                }
-                entry["alora_invocation_string"] = alora_invocation_string;
-                entry["alora_invocation_tokens"] = alora_invocation_tokens;
-            }
-            result.push_back(std::move(entry));
-        }
-        res_ok(res, result);
-        res.status = 200; // HTTP OK
-    };
-
-    const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
-        const json body = json::parse(req.body);
-        if (!body.is_array()) {
-            res_err(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
-            return;
-        }
-
-        int task_id = ctx_server.queue_tasks.get_new_id();
-        {
-            server_task task(SERVER_TASK_TYPE_SET_LORA);
-            task.id = task_id;
-            task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body);
-            ctx_server.queue_results.add_waiting_task_id(task_id);
-            ctx_server.queue_tasks.post(std::move(task));
-        }
-
-        // get the result
-        server_task_result_ptr result = ctx_server.queue_results.recv(task_id);
-        ctx_server.queue_results.remove_waiting_task_id(task_id);
-
-        if (result->is_error()) {
-            res_err(res, result->to_json());
-            return;
-        }
-
-        GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
-        res_ok(res, result->to_json());
-    };
+    server_http_context ctx_http;
+    if (!ctx_http.init(params)) {
+        LOG_ERR("%s: failed to initialize HTTP server\n", __func__);
+        return 1;
+    }

     //
     // Router
     //

-    if (!params.webui) {
-        LOG_INF("Web UI is disabled\n");
-    } else {
-        // register static assets routes
-        if (!params.public_path.empty()) {
-            // Set the base directory for serving static files
-            bool is_found = svr->set_mount_point(params.api_prefix + "/", params.public_path);
-            if (!is_found) {
-                LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
-                return 1;
-            }
-        } else {
-            // using embedded static index.html
-            svr->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
-                if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
-                    res.set_content("Error: gzip is not supported by this browser", "text/plain");
-                } else {
-                    res.set_header("Content-Encoding", "gzip");
-                    // COEP and COOP headers, required by pyodide (python interpreter)
-                    res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
-                    res.set_header("Cross-Origin-Opener-Policy", "same-origin");
-                    res.set_content(reinterpret_cast<const char *>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
-                }
-                return false;
-            });
-        }
-    }
-
-    // register API routes
-    svr->Get (params.api_prefix + "/health", handle_health); // public endpoint (no API key check)
-    svr->Get (params.api_prefix + "/v1/health", handle_health); // public endpoint (no API key check)
-    svr->Get (params.api_prefix + "/metrics", handle_metrics);
-    svr->Get (params.api_prefix + "/props", handle_props);
-    svr->Post(params.api_prefix + "/props", handle_props_change);
-    svr->Post(params.api_prefix + "/api/show", handle_api_show);
-    svr->Get (params.api_prefix + "/models", handle_models); // public endpoint (no API key check)
-    svr->Get (params.api_prefix + "/v1/models", handle_models); // public endpoint (no API key check)
-    svr->Get (params.api_prefix + "/api/tags", handle_models); // ollama specific endpoint. public endpoint (no API key check)
-    svr->Post(params.api_prefix + "/completion", handle_completions); // legacy
-    svr->Post(params.api_prefix + "/completions", handle_completions);
-    svr->Post(params.api_prefix + "/v1/completions", handle_completions_oai);
-    svr->Post(params.api_prefix + "/chat/completions", handle_chat_completions);
-    svr->Post(params.api_prefix + "/v1/chat/completions", handle_chat_completions);
-    svr->Post(params.api_prefix + "/api/chat", handle_chat_completions); // ollama specific endpoint
-    svr->Post(params.api_prefix + "/infill", handle_infill);
-    svr->Post(params.api_prefix + "/embedding", handle_embeddings); // legacy
-    svr->Post(params.api_prefix + "/embeddings", handle_embeddings);
-    svr->Post(params.api_prefix + "/v1/embeddings", handle_embeddings_oai);
-    svr->Post(params.api_prefix + "/rerank", handle_rerank);
-    svr->Post(params.api_prefix + "/reranking", handle_rerank);
-    svr->Post(params.api_prefix + "/v1/rerank", handle_rerank);
-    svr->Post(params.api_prefix + "/v1/reranking", handle_rerank);
-    svr->Post(params.api_prefix + "/tokenize", handle_tokenize);
-    svr->Post(params.api_prefix + "/detokenize", handle_detokenize);
-    svr->Post(params.api_prefix + "/apply-template", handle_apply_template);
+    server_routes routes(params, ctx_server, ctx_http);
+
+    ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
+    ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check)
+    ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics));
+    ctx_http.get ("/props", ex_wrapper(routes.get_props));
+    ctx_http.post("/props", ex_wrapper(routes.post_props));
+    ctx_http.post("/api/show", ex_wrapper(routes.get_api_show));
+    ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
+    ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check)
+    ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check)
+    ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy
+    ctx_http.post("/completions", ex_wrapper(routes.post_completions));
+    ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai));
+    ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions));
+    ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
+    ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint
+    ctx_http.post("/infill", ex_wrapper(routes.post_infill));
+    ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy
+    ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings));
+    ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai));
+    ctx_http.post("/rerank", ex_wrapper(routes.post_rerank));
+    ctx_http.post("/reranking", ex_wrapper(routes.post_rerank));
+    ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank));
+    ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank));
+    ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize));
+    ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize));
+    ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template));

     // LoRA adapters hotswap
-    svr->Get (params.api_prefix + "/lora-adapters", handle_lora_adapters_list);
-    svr->Post(params.api_prefix + "/lora-adapters", handle_lora_adapters_apply);
+    ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters));
+    ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters));

     // Save & load slots
-    svr->Get (params.api_prefix + "/slots", handle_slots);
-    svr->Post(params.api_prefix + "/slots/:id_slot", handle_slots_action);
+    ctx_http.get ("/slots", ex_wrapper(routes.get_slots));
+    ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots));

     //
     // Start the server
     //

-    if (params.n_threads_http < 1) {
-        // +2 threads for monitoring endpoints
-        params.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
-    }
-    log_data["n_threads_http"] = std::to_string(params.n_threads_http);
-    svr->new_task_queue = [&params] { return new httplib::ThreadPool(params.n_threads_http); };

-    // clean up function, to be called before exit
-    auto clean_up = [&svr, &ctx_server]() {
+    // setup clean up function, to be called before exit
+    auto clean_up = [&ctx_http, &ctx_server]() {
         SRV_INF("%s: cleaning up before exit...\n", __func__);
-        svr->stop();
+        ctx_http.stop();
         ctx_server.queue_results.terminate();
         llama_backend_free();
     };

-    bool was_bound = false;
-    bool is_sock = false;
-    if (string_ends_with(std::string(params.hostname), ".sock")) {
-        is_sock = true;
-        LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
-        svr->set_address_family(AF_UNIX);
-        // bind_to_port requires a second arg, any value other than 0 should
-        // simply get ignored
-        was_bound = svr->bind_to_port(params.hostname, 8080);
-    } else {
-        LOG_INF("%s: binding port with default address family\n", __func__);
-        // bind HTTP listen port
-        if (params.port == 0) {
-            int bound_port = svr->bind_to_any_port(params.hostname);
-            if ((was_bound = (bound_port >= 0))) {
-                params.port = bound_port;
-            }
-        } else {
-            was_bound = svr->bind_to_port(params.hostname, params.port);
-        }
-    }
-
-    if (!was_bound) {
-        LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, params.hostname.c_str(), params.port);
+    // start the HTTP server before loading the model to be able to serve /health requests
+    if (!ctx_http.start()) {
         clean_up();
+        LOG_ERR("%s: exiting due to HTTP server error\n", __func__);
         return 1;
     }

-    // run the HTTP server in a thread
-    std::thread t([&]() { svr->listen_after_bind(); });
-    svr->wait_until_ready();
-
-    LOG_INF("%s: HTTP server is listening, hostname: %s, port: %d, http threads: %d\n", __func__, params.hostname.c_str(), params.port, params.n_threads_http);
-
     // load the model
     LOG_INF("%s: loading model\n", __func__);

     if (!ctx_server.load_model(params)) {
         clean_up();
-        t.join();
+        if (ctx_http.thread.joinable()) {
+            ctx_http.thread.join();
+        }
         LOG_ERR("%s: exiting due to model loading error\n", __func__);
         return 1;
     }

     ctx_server.init();
-    state.store(SERVER_STATE_READY);
+    ctx_http.is_ready.store(true);

     LOG_INF("%s: model loaded\n", __func__);

-    // print sample chat example to make it clear which template is used
-    LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
-        common_chat_templates_source(ctx_server.chat_templates.get()),
-        common_chat_format_example(ctx_server.chat_templates.get(), ctx_server.params_base.use_jinja, ctx_server.params_base.default_template_kwargs).c_str());
-
     ctx_server.queue_tasks.on_new_task([&ctx_server](server_task && task) {
         ctx_server.process_single_task(std::move(task));
     });

@@ -5791,15 +5646,15 @@ int main(int argc, char ** argv) {
     SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
 #endif

-    LOG_INF("%s: server is listening on %s - starting the main loop\n", __func__,
-            is_sock ? string_format("unix://%s", params.hostname.c_str()).c_str() :
-                      string_format("http://%s:%d", params.hostname.c_str(), params.port).c_str());
-
+    LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str());
+    LOG_INF("%s: starting the main loop...\n", __func__);

     // this call blocks the main thread until queue_tasks.terminate() is called
     ctx_server.queue_tasks.start_loop();

     clean_up();
-    t.join();
+    if (ctx_http.thread.joinable()) {
+        ctx_http.thread.join();
+    }

     llama_memory_breakdown_print(ctx_server.ctx);

     return 0;
diff --git a/tools/server/utils.hpp b/tools/server/utils.hpp
index b1ecc5af5e..bf21726051 100644
--- a/tools/server/utils.hpp
+++ b/tools/server/utils.hpp
@@ -9,8 +9,6 @@
 #include "mtmd-helper.h"
 #include "chat.h"

-#include
-
 #define JSON_ASSERT GGML_ASSERT
 #include <nlohmann/json.hpp>
@@ -426,6 +424,10 @@ static std::string gen_tool_call_id() {
 // other common utils
 //

+static std::string safe_json_to_str(const json & data) {
+    return data.dump(-1, ' ', false, json::error_handler_t::replace);
+}
+
 // TODO: reuse llama_detokenize
 template <class Iter>
 static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) {
@@ -453,29 +455,25 @@ static std::string tokens_to_output_formatted_string(const llama_context * ctx,
     return out;
 }

+// format server-sent event (SSE), return the formatted string to send
 // note: if data is a json array, it will be sent as multiple events, one per item
-static bool server_sent_event(httplib::DataSink & sink, const json & data) {
-    static auto send_single = [](httplib::DataSink & sink, const json & data) -> bool {
-        const std::string str =
-            "data: " +
-            data.dump(-1, ' ', false, json::error_handler_t::replace) +
+static std::string format_sse(const json & data) {
+    std::ostringstream ss;
+    auto send_single = [&ss](const json & data) {
+        ss << "data: " <<
+              safe_json_to_str(data) <<
+              "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
-
-        LOG_DBG("data stream, to_send: %s", str.c_str());
-        return sink.write(str.c_str(), str.size());
     };

     if (data.is_array()) {
         for (const auto & item : data) {
-            if (!send_single(sink, item)) {
-                return false;
-            }
+            send_single(item);
         }
     } else {
-        return send_single(sink, data);
+        send_single(data);
     }
-    return true;
+    return ss.str();
 }

 //
@@ -954,10 +952,6 @@ static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias)
     return data;
 }

-static std::string safe_json_to_str(const json & data) {
-    return data.dump(-1, ' ', false, json::error_handler_t::replace);
-}
-
 static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
     std::vector<llama_token_data> cur;
     const auto * logits = llama_get_logits_ith(ctx, idx);
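Every route above is registered through ex_wrapper, so individual handlers are free to throw (json::parse on a malformed body, for instance) and the wrapper converts the exception into a JSON 500 response instead of letting it escape into the HTTP layer. A self-contained sketch of the same pattern, using simplified stand-in types rather than the real server_http_req/server_http_res definitions:

#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>

// Minimal stand-ins for the types in server-http.h; the real definitions differ.
struct http_res {
    int         status = 200;
    std::string body;
};
using http_res_ptr = std::unique_ptr<http_res>;
using handler_t    = std::function<http_res_ptr(const std::string & req_body)>;

// Same shape as ex_wrapper above: the returned handler never propagates exceptions.
static handler_t wrap(handler_t func) {
    return [func = std::move(func)](const std::string & req_body) -> http_res_ptr {
        std::string message;
        try {
            return func(req_body);
        } catch (const std::exception & e) {
            message = e.what();
        } catch (...) {
            message = "unknown error";
        }
        auto res = std::make_unique<http_res>();
        res->status = 500;
        res->body   = "{\"error\":{\"message\":\"" + message + "\"}}";
        return res;
    };
}

int main() {
    handler_t handler = wrap([](const std::string &) -> http_res_ptr {
        throw std::runtime_error("boom"); // e.g. json::parse failing on a bad request body
    });
    auto res = handler("{}");
    std::cout << res->status << " " << res->body << "\n"; // 500 {"error":{"message":"boom"}}
    return 0;
}

This centralization is what lets handlers such as handle_slots_save call json::parse(req.body) with no local try/catch at all.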
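The SSE change in utils.hpp is behavioral as well as cosmetic: server_sent_event wrote each event straight into httplib's DataSink, while format_sse only builds the wire-format string and leaves the actual write to the transport layer in server-http.cpp. A standalone sketch of the same formatting logic and the output it produces (demo-only names; assumes nlohmann/json is available):

#include <nlohmann/json.hpp>
#include <iostream>
#include <sstream>
#include <string>

using json = nlohmann::json;

// Same logic as format_sse above, reproduced standalone for illustration.
static std::string format_sse_demo(const json & data) {
    std::ostringstream ss;
    auto send_single = [&ss](const json & d) {
        // invalid UTF-8 is replaced rather than throwing, as in safe_json_to_str
        ss << "data: " << d.dump(-1, ' ', false, json::error_handler_t::replace) << "\n\n";
    };
    if (data.is_array()) {
        for (const auto & item : data) {
            send_single(item);
        }
    } else {
        send_single(data);
    }
    return ss.str();
}

int main() {
    // an array becomes one SSE event per element, each terminated by a blank line
    std::cout << format_sse_demo(json::array({ {{"token", 1}}, {{"token", 2}} }));
    // data: {"token":1}
    //
    // data: {"token":2}
    //
    return 0;
}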