#include "utils.hpp"
#include "common.h"
#include "server-http.h"

#include <cpp-httplib/httplib.h>

#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <thread>
#include <unordered_set>

// auto generated files (see README.md for details)
#include "index.html.gz.hpp"
#include "loading.html.hpp"

//
// HTTP implementation using cpp-httplib
//

// Opaque implementation: owns the cpp-httplib server instance (plain or SSL).
class server_http_context::Impl {
public:
    std::unique_ptr<httplib::Server> srv;
};

server_http_context::server_http_context()
    : pimpl(std::make_unique<server_http_context::Impl>())
{}

server_http_context::~server_http_context() = default;

// Logger callback installed via set_logger(): logs method, path, remote address
// and status at info level, and the request/response bodies at debug level.
static void log_server_request(const httplib::Request & req, const httplib::Response & res) {
    // skip GH copilot requests when using default port
    if (req.path == "/v1/health") {
        return;
    }

    // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch

    SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status);

    SRV_DBG("request: %s\n", req.body.c_str());
    SRV_DBG("response: %s\n", res.body.c_str());
}

// Creates and configures the underlying HTTP server from `params`:
// SSL/non-SSL server, logging, fallback exception/error handlers, timeouts,
// CORS + readiness + API-key pre-routing middleware, the worker thread pool,
// and (optionally) the Web UI routes.
// Returns false on a configuration error (SSL requested without SSL support,
// static assets path not found); the server is not started here, see start().
bool server_http_context::init(const common_params & params) {
    path_prefix = params.api_prefix;
    port        = params.port;
    hostname    = params.hostname;

    const bool ssl_requested = !params.ssl_file_key.empty() && !params.ssl_file_cert.empty();

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
    if (ssl_requested) {
        LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str());
        pimpl->srv = std::make_unique<httplib::SSLServer>(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str());
    } else {
        LOG_INF("Running without SSL\n");
        pimpl->srv = std::make_unique<httplib::Server>();
    }
#else
    if (ssl_requested) {
        LOG_ERR("Server is built without SSL support\n");
        return false;
    }
    pimpl->srv = std::make_unique<httplib::Server>();
#endif

    auto & srv = pimpl->srv;

    srv->set_default_headers({{"Server", "llama.cpp"}});
    srv->set_logger(log_server_request);

    srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) {
        // this is a fail-safe; exceptions should already be handled by `ex_wrapper`

        std::string message;
        try {
            std::rethrow_exception(ep);
        } catch (const std::exception & e) {
            message = e.what();
        } catch (...) {
            message = "Unknown Exception";
        }

        res.status = 500;
        res.set_content(message, "text/plain");
        LOG_ERR("got exception: %s\n", message.c_str());
    });

    srv->set_error_handler([](const httplib::Request &, httplib::Response & res) {
        if (res.status == 404) {
            res.set_content(
                safe_json_to_str(json {
                    {"error", {
                        {"message", "File Not Found"},
                        {"type", "not_found_error"},
                        {"code", 404}
                    }}
                }),
                "application/json; charset=utf-8"
            );
        }
        // for other error codes, we skip processing here because it's already done by res->error()
    });

    // set timeouts and change hostname and port
    srv->set_read_timeout (params.timeout_read);
    srv->set_write_timeout(params.timeout_write);

    if (params.api_keys.size() == 1) {
        const std::string & key = params.api_keys[0];
        // only the last 4 characters are logged; keys of 4 characters or fewer are shown in full
        const std::string substr = key.length() > 4 ? key.substr(key.length() - 4) : key;
        LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str());
    } else if (params.api_keys.size() > 1) {
        LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size());
    }

    //
    // Middlewares
    //

    // Returns true if the request may proceed; otherwise fills `res` with a 401 and returns false.
    auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) {
        // NOTE(review): req.path is compared verbatim, so when api_prefix is non-empty the
        // prefixed variants of these endpoints (and of "/") are not treated as public — confirm intent
        static const std::unordered_set<std::string> public_endpoints = {
            "/health",
            "/v1/health",
            "/models",
            "/v1/models",
            "/api/tags"
        };

        // If API key is not set, skip validation
        if (api_keys.empty()) {
            return true;
        }

        // If path is public or is static file, skip validation
        if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") {
            return true;
        }

        // Check for API key in the header
        const std::string auth_header = req.get_header_value("Authorization");
        const std::string prefix = "Bearer ";
        if (auth_header.substr(0, prefix.size()) == prefix) {
            const std::string received_api_key = auth_header.substr(prefix.size());
            if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) {
                return true; // API key is valid
            }
        }

        // API key is invalid or not provided
        res.status = 401;
        res.set_content(
            safe_json_to_str(json {
                {"error", {
                    {"message", "Invalid API Key"},
                    {"type", "authentication_error"},
                    {"code", 401}
                }}
            }),
            "application/json; charset=utf-8"
        );

        LOG_WRN("Unauthorized: Invalid API Key\n");

        return false;
    };

    // While is_ready is false: serves the loading page for "/" and *.html (503),
    // lets the model-listing endpoints through, and answers everything else with a
    // JSON 503. Returns true if the request may proceed.
    auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) {
        const bool ready = is_ready.load();
        if (ready) {
            return true;
        }

        const auto tmp = string_split(req.path, '.');
        if (req.path == "/" || tmp.back() == "html") {
            res.set_content(reinterpret_cast<const char *>(loading_html), loading_html_len, "text/html; charset=utf-8");
            res.status = 503;
        } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") {
            // allow the models endpoint to be accessed during loading
            return true;
        } else {
            res.status = 503;
            res.set_content(
                safe_json_to_str(json {
                    {"error", {
                        {"message", "Loading model"},
                        {"type", "unavailable_error"},
                        {"code", 503}
                    }}
                }),
                "application/json; charset=utf-8"
            );
        }
        return false;
    };

    // register server middlewares
    srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) {
        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
        // If this is OPTIONS request, skip validation because browsers don't include Authorization header
        if (req.method == "OPTIONS") {
            res.set_header("Access-Control-Allow-Credentials", "true");
            res.set_header("Access-Control-Allow-Methods",     "GET, POST");
            res.set_header("Access-Control-Allow-Headers",     "*");
            res.set_content("", "text/html"); // blank response, no data
            return httplib::Server::HandlerResponse::Handled; // skip further processing
        }
        if (!middleware_server_state(req, res)) {
            return httplib::Server::HandlerResponse::Handled;
        }
        if (!middleware_validate_api_key(req, res)) {
            return httplib::Server::HandlerResponse::Handled;
        }
        return httplib::Server::HandlerResponse::Unhandled;
    });

    int n_threads_http = params.n_threads_http;
    if (n_threads_http < 1) {
        // +2 threads for monitoring endpoints
        n_threads_http = std::max(params.n_parallel + 2, static_cast<int32_t>(std::thread::hardware_concurrency()) - 1);
    }
    LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http);
    // httplib takes ownership of the returned queue
    srv->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); };

    //
    // Web UI setup
    //

    if (!params.webui) {
        LOG_INF("Web UI is disabled\n");
    } else {
        // register static assets routes
        if (!params.public_path.empty()) {
            // Set the base directory for serving static files
            const bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path);
            if (!is_found) {
                LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str());
                return false; // configuration error: report failure to the caller
            }
        } else {
            // using embedded static index.html
            srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) {
                if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) {
                    res.set_content("Error: gzip is not supported by this browser", "text/plain");
                } else {
                    res.set_header("Content-Encoding", "gzip");
                    // COEP and COOP headers, required by pyodide (python interpreter)
                    res.set_header("Cross-Origin-Embedder-Policy", "require-corp");
                    res.set_header("Cross-Origin-Opener-Policy", "same-origin");
                    res.set_content(reinterpret_cast<const char *>(index_html_gz), index_html_gz_len, "text/html; charset=utf-8");
                }
            });
        }
    }
    return true;
}

// Binds the server socket and runs listen_after_bind() on `thread`.
// A hostname ending in ".sock" binds a Unix domain socket; otherwise a TCP port
// is bound (port 0 picks any free port and stores it back into `port`).
// Blocks until the server reports ready and sets `listening_address`.
// Returns false if init() has not created a server or binding fails.
bool server_http_context::start() {
    // Bind and listen

    auto & srv = pimpl->srv;
    if (!srv) {
        LOG_ERR("%s: HTTP server is not initialized; call init() first\n", __func__);
        return false;
    }

    bool was_bound = false;
    bool is_sock = false;
    if (string_ends_with(std::string(hostname), ".sock")) {
        is_sock = true;
        LOG_INF("%s: setting address family to AF_UNIX\n", __func__);
        srv->set_address_family(AF_UNIX);
        // bind_to_port requires a second arg, any value other than 0 should
        // simply get ignored
        was_bound = srv->bind_to_port(hostname, 8080);
    } else {
        LOG_INF("%s: binding port with default address family\n", __func__);
        // bind HTTP listen port
        if (port == 0) {
            const int bound_port = srv->bind_to_any_port(hostname);
            was_bound = (bound_port >= 0);
            if (was_bound) {
                port = bound_port;
            }
        } else {
            was_bound = srv->bind_to_port(hostname, port);
        }
    }

    if (!was_bound) {
        LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port);
        return false;
    }

    // run the HTTP server in a thread
    thread = std::thread([this]() { pimpl->srv->listen_after_bind(); });
    srv->wait_until_ready();

    listening_address = is_sock ? string_format("unix://%s",    hostname.c_str())
                                : string_format("http://%s:%d", hostname.c_str(), port);
    return true;
}

// Asks the server (if one was created) to stop listening.
void server_http_context::stop() const {
    if (pimpl->srv) {
        pimpl->srv->stop();
    }
}

// Copies every header in `headers` onto the httplib response.
static void set_headers(httplib::Response & res, const std::map<std::string, std::string> & headers) {
    for (const auto & [key, value] : headers) {
        res.set_header(key, value);
    }
}

// Flattens query params and path params into one map; for duplicate keys the
// last assignment wins (path params override query params).
static std::map<std::string, std::string> get_params(const httplib::Request & req) {
    std::map<std::string, std::string> params;
    for (const auto & [key, value] : req.params) {
        params[key] = value;
    }
    for (const auto & [key, value] : req.path_params) {
        params[key] = value;
    }
    return params;
}

// Flattens request headers into a map; for duplicate header names the last one wins.
static std::map<std::string, std::string> get_headers(const httplib::Request & req) {
    std::map<std::string, std::string> headers;
    for (const auto & [key, value] : req.headers) {
        headers[key] = value;
    }
    return headers;
}

// Builds the transport-independent request object handed to route handlers.
// Shared by get() and post() so both routes convert requests identically.
static server_http_req make_server_http_req(const httplib::Request & req) {
    return server_http_req{
        get_params(req),
        get_headers(req),
        req.path,
        req.body,
        req.is_connection_closed
    };
}

// Writes a handler's response into the httplib response. Streaming responses
// are served through a chunked content provider that pulls chunks via next();
// `response` is moved from (left null) in that case.
static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) {
    if (response->is_stream()) {
        res.status = response->status;
        set_headers(res, response->headers);
        const std::string content_type = response->content_type;
        // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it
        std::shared_ptr<server_http_res> r_ptr = std::move(response);
        const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool {
            std::string chunk;
            const bool has_next = response->next(chunk);
            if (!chunk.empty()) {
                // TODO: maybe handle sink.write unsuccessful? for now, we rely on is_connection_closed()
                sink.write(chunk.data(), chunk.size());
                SRV_DBG("http: streamed chunk: %s\n", chunk.c_str());
            }
            if (!has_next) {
                sink.done();
                SRV_DBG("%s", "http: stream ended\n");
            }
            return has_next;
        };
        const auto on_complete = [response = r_ptr](bool) mutable {
            // drop this closure's reference; the response object is destroyed once
            // the content-provider closure (which holds another copy) is released too
            response.reset();
        };
        res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete);
    } else {
        res.status = response->status;
        set_headers(res, response->headers);
        res.set_content(response->data, response->content_type);
    }
}

// Registers `handler` for GET requests on path_prefix + path.
void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const {
    pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
        server_http_res_ptr response = handler(make_server_http_req(req));
        process_handler_response(response, res);
    });
}

// Registers `handler` for POST requests on path_prefix + path.
void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const {
    pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) {
        server_http_res_ptr response = handler(make_server_http_req(req));
        process_handler_response(response, res);
    });
}