mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	server : better security control for public deployments (#9776)
* server : more explicit endpoint access settings * protect /props endpoint * fix tests * update server docs * fix typo * fix tests
This commit is contained in:
		| @@ -1106,12 +1106,7 @@ struct server_context { | ||||
|         SRV_DBG("system prompt set: '%s'\n", system_prompt.c_str()); | ||||
|  | ||||
|         system_prompt = sys_prompt; | ||||
|  | ||||
|         // release all slots | ||||
|         for (server_slot & slot : slots) { | ||||
|             slot.release(); | ||||
|         } | ||||
|  | ||||
|         // update system_tokens and KV cache as soon as all slots are idle | ||||
|         system_need_update = true; | ||||
|         return true; | ||||
|     } | ||||
| @@ -1627,16 +1622,6 @@ struct server_context { | ||||
|                         break; | ||||
|                     } | ||||
|  | ||||
|                     if (task.data.contains("system_prompt")) { | ||||
|                         std::string sys_prompt = json_value(task.data, "system_prompt", std::string()); | ||||
|                         system_prompt_set(sys_prompt); | ||||
|  | ||||
|                         for (server_slot & slot : slots) { | ||||
|                             slot.n_past    = 0; | ||||
|                             slot.n_past_se = 0; | ||||
|                         } | ||||
|                     } | ||||
|  | ||||
|                     slot->reset(); | ||||
|  | ||||
|                     slot->id_task   = task.id; | ||||
| @@ -1862,10 +1847,6 @@ struct server_context { | ||||
|     } | ||||
|  | ||||
|     void update_slots() { | ||||
|         if (system_need_update) { | ||||
|             system_prompt_update(); | ||||
|         } | ||||
|  | ||||
|         // check if all slots are idle | ||||
|         { | ||||
|             bool all_idle = true; | ||||
| @@ -1878,6 +1859,10 @@ struct server_context { | ||||
|             } | ||||
|  | ||||
|             if (all_idle) { | ||||
|                 if (system_need_update) { | ||||
|                     system_prompt_update(); | ||||
|                 } | ||||
|  | ||||
|                 SRV_INF("%s", "all slots are idle\n"); | ||||
|                 if (system_prompt.empty() && clean_kv_cache) { | ||||
|                     kv_cache_clear(); | ||||
| @@ -2536,20 +2521,10 @@ int main(int argc, char ** argv) { | ||||
|     // | ||||
|  | ||||
|     auto middleware_validate_api_key = [¶ms, &res_error](const httplib::Request & req, httplib::Response & res) { | ||||
|         // TODO: should we apply API key to all endpoints, including "/health" and "/models"? | ||||
|         static const std::unordered_set<std::string> protected_endpoints = { | ||||
|             "/props", | ||||
|             "/completion", | ||||
|             "/completions", | ||||
|             "/v1/completions", | ||||
|             "/chat/completions", | ||||
|             "/v1/chat/completions", | ||||
|             "/infill", | ||||
|             "/tokenize", | ||||
|             "/detokenize", | ||||
|             "/embedding", | ||||
|             "/embeddings", | ||||
|             "/v1/embeddings", | ||||
|         static const std::unordered_set<std::string> public_endpoints = { | ||||
|             "/health", | ||||
|             "/models", | ||||
|             "/v1/models", | ||||
|         }; | ||||
|  | ||||
|         // If API key is not set, skip validation | ||||
| @@ -2557,8 +2532,8 @@ int main(int argc, char ** argv) { | ||||
|             return true; | ||||
|         } | ||||
|  | ||||
|         // If path is not in protected_endpoints list, skip validation | ||||
|         if (protected_endpoints.find(req.path) == protected_endpoints.end()) { | ||||
|         // If path is public, skip validation | ||||
|         if (public_endpoints.find(req.path) != public_endpoints.end()) { | ||||
|             return true; | ||||
|         } | ||||
|  | ||||
| @@ -2620,7 +2595,7 @@ int main(int argc, char ** argv) { | ||||
|  | ||||
|     const auto handle_slots = [&](const httplib::Request & req, httplib::Response & res) { | ||||
|         if (!params.endpoint_slots) { | ||||
|             res_error(res, format_error_response("This server does not support slots endpoint. Start it without `--no-slots`", ERROR_TYPE_NOT_SUPPORTED)); | ||||
|             res_error(res, format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); | ||||
|             return; | ||||
|         } | ||||
|  | ||||
| @@ -2869,24 +2844,31 @@ int main(int argc, char ** argv) { | ||||
|     }; | ||||
|  | ||||
|     const auto handle_props = [&ctx_server, &res_ok](const httplib::Request &, httplib::Response & res) { | ||||
|         std::string template_key = "tokenizer.chat_template", curr_tmpl; | ||||
|         int32_t tlen = llama_model_meta_val_str(ctx_server.model, template_key.c_str(), nullptr, 0); | ||||
|         if (tlen > 0) { | ||||
|             std::vector<char> curr_tmpl_buf(tlen + 1, 0); | ||||
|             if (llama_model_meta_val_str(ctx_server.model, template_key.c_str(), curr_tmpl_buf.data(), curr_tmpl_buf.size()) == tlen) { | ||||
|                 curr_tmpl = std::string(curr_tmpl_buf.data(), tlen); | ||||
|             } | ||||
|         } | ||||
|         json data = { | ||||
|             { "system_prompt",               ctx_server.system_prompt.c_str() }, | ||||
|             { "system_prompt",               ctx_server.system_prompt }, | ||||
|             { "default_generation_settings", ctx_server.default_generation_settings_for_props }, | ||||
|             { "total_slots",                 ctx_server.params.n_parallel }, | ||||
|             { "chat_template",               curr_tmpl.c_str() }, | ||||
|             { "chat_template",               llama_get_chat_template(ctx_server.model) }, | ||||
|         }; | ||||
|  | ||||
|         res_ok(res, data); | ||||
|     }; | ||||
|  | ||||
|     const auto handle_props_change = [&ctx_server, &res_error, &res_ok](const httplib::Request & req, httplib::Response & res) { | ||||
|         if (!ctx_server.params.endpoint_props) { | ||||
|             res_error(res, format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); | ||||
|             return; | ||||
|         } | ||||
|  | ||||
|         json data = json::parse(req.body); | ||||
|         if (data.contains("system_prompt")) { | ||||
|             std::string system_prompt = data.at("system_prompt"); | ||||
|             ctx_server.system_prompt_set(system_prompt); | ||||
|         } | ||||
|  | ||||
|         res_ok(res, {{ "success", true }}); | ||||
|     }; | ||||
|  | ||||
|     const auto handle_completions_generic = [&ctx_server, &res_error, &res_ok](server_task_cmpl_type cmpl_type, json & data, httplib::Response & res) { | ||||
|         if (ctx_server.params.embedding || ctx_server.params.reranking) { | ||||
|             res_error(res, format_error_response("This server does not support completions. Start it without `--embeddings` or `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); | ||||
| @@ -3265,30 +3247,39 @@ int main(int argc, char ** argv) { | ||||
|         svr->set_base_dir(params.public_path); | ||||
|     } | ||||
|  | ||||
|     // using embedded static files | ||||
|     svr->Get("/",                           handle_static_file(index_html, index_html_len, "text/html; charset=utf-8")); | ||||
|     svr->Get("/index.js",                   handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8")); | ||||
|     svr->Get("/completion.js",              handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8")); | ||||
|     svr->Get("/json-schema-to-grammar.mjs", handle_static_file(json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8")); | ||||
|     if (!params.api_keys.empty()) { | ||||
|         // for now, if API key is set, web UI is unusable | ||||
|         svr->Get("/", [&](const httplib::Request &, httplib::Response & res) { | ||||
|             return res.set_content("Web UI is disabled because API key is set.", "text/html; charset=utf-8"); | ||||
|         }); | ||||
|     } else { | ||||
|         // using embedded static files | ||||
|         svr->Get("/",                           handle_static_file(index_html, index_html_len, "text/html; charset=utf-8")); | ||||
|         svr->Get("/index.js",                   handle_static_file(index_js, index_js_len, "text/javascript; charset=utf-8")); | ||||
|         svr->Get("/completion.js",              handle_static_file(completion_js, completion_js_len, "text/javascript; charset=utf-8")); | ||||
|         svr->Get("/json-schema-to-grammar.mjs", handle_static_file(json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8")); | ||||
|  | ||||
|     // add new-ui files | ||||
|     svr->Get("/colorthemes.css",       handle_static_file(colorthemes_css, colorthemes_css_len, "text/css; charset=utf-8")); | ||||
|     svr->Get("/style.css",             handle_static_file(style_css, style_css_len, "text/css; charset=utf-8")); | ||||
|     svr->Get("/theme-beeninorder.css", handle_static_file(theme_beeninorder_css, theme_beeninorder_css_len, "text/css; charset=utf-8")); | ||||
|     svr->Get("/theme-ketivah.css",     handle_static_file(theme_ketivah_css, theme_ketivah_css_len, "text/css; charset=utf-8")); | ||||
|     svr->Get("/theme-mangotango.css",  handle_static_file(theme_mangotango_css, theme_mangotango_css_len, "text/css; charset=utf-8")); | ||||
|     svr->Get("/theme-playground.css",  handle_static_file(theme_playground_css, theme_playground_css_len, "text/css; charset=utf-8")); | ||||
|     svr->Get("/theme-polarnight.css",  handle_static_file(theme_polarnight_css, theme_polarnight_css_len, "text/css; charset=utf-8")); | ||||
|     svr->Get("/theme-snowstorm.css",   handle_static_file(theme_snowstorm_css, theme_snowstorm_css_len, "text/css; charset=utf-8")); | ||||
|     svr->Get("/index-new.html",        handle_static_file(index_new_html, index_new_html_len, "text/html; charset=utf-8")); | ||||
|     svr->Get("/system-prompts.js",     handle_static_file(system_prompts_js, system_prompts_js_len, "text/javascript; charset=utf-8")); | ||||
|     svr->Get("/prompt-formats.js",     handle_static_file(prompt_formats_js, prompt_formats_js_len, "text/javascript; charset=utf-8")); | ||||
|         // add new-ui files | ||||
|         svr->Get("/colorthemes.css",       handle_static_file(colorthemes_css, colorthemes_css_len, "text/css; charset=utf-8")); | ||||
|         svr->Get("/style.css",             handle_static_file(style_css, style_css_len, "text/css; charset=utf-8")); | ||||
|         svr->Get("/theme-beeninorder.css", handle_static_file(theme_beeninorder_css, theme_beeninorder_css_len, "text/css; charset=utf-8")); | ||||
|         svr->Get("/theme-ketivah.css",     handle_static_file(theme_ketivah_css, theme_ketivah_css_len, "text/css; charset=utf-8")); | ||||
|         svr->Get("/theme-mangotango.css",  handle_static_file(theme_mangotango_css, theme_mangotango_css_len, "text/css; charset=utf-8")); | ||||
|         svr->Get("/theme-playground.css",  handle_static_file(theme_playground_css, theme_playground_css_len, "text/css; charset=utf-8")); | ||||
|         svr->Get("/theme-polarnight.css",  handle_static_file(theme_polarnight_css, theme_polarnight_css_len, "text/css; charset=utf-8")); | ||||
|         svr->Get("/theme-snowstorm.css",   handle_static_file(theme_snowstorm_css, theme_snowstorm_css_len, "text/css; charset=utf-8")); | ||||
|         svr->Get("/index-new.html",        handle_static_file(index_new_html, index_new_html_len, "text/html; charset=utf-8")); | ||||
|         svr->Get("/system-prompts.js",     handle_static_file(system_prompts_js, system_prompts_js_len, "text/javascript; charset=utf-8")); | ||||
|         svr->Get("/prompt-formats.js",     handle_static_file(prompt_formats_js, prompt_formats_js_len, "text/javascript; charset=utf-8")); | ||||
|     } | ||||
|  | ||||
|     // register API routes | ||||
|     svr->Get ("/health",              handle_health); | ||||
|     svr->Get ("/health",              handle_health); // public endpoint (no API key check) | ||||
|     svr->Get ("/metrics",             handle_metrics); | ||||
|     svr->Get ("/props",               handle_props); | ||||
|     svr->Get ("/v1/models",           handle_models); | ||||
|     svr->Post("/props",               handle_props_change); | ||||
|     svr->Get ("/models",              handle_models); // public endpoint (no API key check) | ||||
|     svr->Get ("/v1/models",           handle_models); // public endpoint (no API key check) | ||||
|     svr->Post("/completion",          handle_completions); // legacy | ||||
|     svr->Post("/completions",         handle_completions); | ||||
|     svr->Post("/v1/completions",      handle_completions); | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Xuan Son Nguyen
					Xuan Son Nguyen