Mirror of https://github.com/ggml-org/llama.cpp.git
	server : fix format_infill (#10724)
* server : fix format_infill
* fix
* rename
* update test
* use another model
* update test
* update test
* test_invalid_input_extra_req
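With this change the server assembles the infill prompt itself: the optional "prompt" field is validated as a string, tokenized, and combined with "input_prefix", "input_suffix" and "input_extra" through format_infill before the infill task is dispatched (see the server diff below). A minimal client sketch against the /infill endpoint; the host/port and the use of urllib are illustrative and not part of the commit:

# Sketch: exercise the new /infill request shape. Assumes a llama-server
# instance listening on http://localhost:8080 (illustrative default);
# only the Python standard library is used.
import json
import urllib.request

payload = {
    "input_extra": [{"filename": "llama.h",
                     "text": "LLAMA_API int32_t llama_n_threads();\n"}],
    "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
    "prompt": "    int n_threads = llama_",  # optional, must be a string if present
    "input_suffix": "}\n",
}

req = urllib.request.Request(
    "http://localhost:8080/infill",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.loads(resp.read())["content"])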
@@ -3484,6 +3484,11 @@ int main(int argc, char ** argv) {
         json data = json::parse(req.body);
 
         // validate input
+        if (data.contains("prompt") && !data.at("prompt").is_string()) {
+            // prompt is optional
+            res_error(res, format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST));
+        }
+
         if (!data.contains("input_prefix")) {
             res_error(res, format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST));
         }
@@ -3493,9 +3498,11 @@ int main(int argc, char ** argv) {
         }
 
         if (data.contains("input_extra") && !data.at("input_extra").is_array()) {
+            // input_extra is optional
             res_error(res, format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
+
         json input_extra = json_value(data, "input_extra", json::array());
         for (const auto & chunk : input_extra) {
             // { "text": string, "filename": string }
@@ -3511,6 +3518,21 @@ int main(int argc, char ** argv) {
         }
         data["input_extra"] = input_extra; // default to empty array if it's not exist
 
+        std::string prompt = json_value(data, "prompt", std::string());
+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+        SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size());
+        data["prompt"] = format_infill(
+            ctx_server.ctx,
+            data.at("input_prefix"),
+            data.at("input_suffix"),
+            data.at("input_extra"),
+            ctx_server.params_base.n_batch,
+            ctx_server.params_base.n_predict,
+            ctx_server.slots[0].n_ctx, // TODO: there should be a better way
+            ctx_server.params_base.spm_infill,
+            tokenized_prompts[0]
+        );
+
         return handle_completions_generic(SERVER_TASK_TYPE_INFILL, data, res);
     };
 
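As the updated tests below reflect, clients that previously embedded the code under the cursor at the end of "input_prefix" can now send it in the optional "prompt" field and let the server build the final infill prompt. A rough sketch of the two request shapes, with values copied from the test changes; the variable names are only for illustration:

# Request bodies as they appear in the test changes below. Before the fix the
# current line was appended to input_prefix; after it, the line goes into the
# optional "prompt" field and format_infill folds it in server-side.
old_test_body = {
    "prompt": "Complete this",
    "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
    "input_suffix": "}\n",
}
new_test_body = {
    "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
    "prompt": "    int n_threads = llama_",
    "input_suffix": "}\n",
}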
@@ -13,28 +13,28 @@ def test_infill_without_input_extra():
     global server
     server.start()
     res = server.make_request("POST", "/infill", data={
-        "prompt": "Complete this",
-        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(One|day|she|saw|big|scary|bird)+", res.body["content"])
+    assert match_regex("(Ann|small|shiny)+", res.body["content"])
 
 
 def test_infill_with_input_extra():
     global server
     server.start()
     res = server.make_request("POST", "/infill", data={
-        "prompt": "Complete this",
         "input_extra": [{
             "filename": "llama.h",
             "text": "LLAMA_API int32_t llama_n_threads();\n"
         }],
-        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
         "input_suffix": "}\n",
     })
     assert res.status_code == 200
-    assert match_regex("(cuts|Jimmy|mom|came|into|the|room)+", res.body["content"])
+    assert match_regex("(Dad|excited|park)+", res.body["content"])
 
 
 @pytest.mark.parametrize("input_extra", [
@@ -48,10 +48,30 @@ def test_invalid_input_extra_req(input_extra):
     global server
     server.start()
     res = server.make_request("POST", "/infill", data={
-        "prompt": "Complete this",
         "input_extra": [input_extra],
-        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n    int n_threads = llama_",
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
         "input_suffix": "}\n",
     })
     assert res.status_code == 400
     assert "error" in res.body
+
+
+@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test")
+def test_with_qwen_model():
+    global server
+    server.model_file = None
+    server.model_hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-IQ3_XXS-GGUF"
+    server.model_hf_file = "qwen2.5-coder-1.5b-iq3_xxs-imat.gguf"
+    server.start(timeout_seconds=600)
+    res = server.make_request("POST", "/infill", data={
+        "input_extra": [{
+            "filename": "llama.h",
+            "text": "LLAMA_API int32_t llama_n_threads();\n"
+        }],
+        "input_prefix": "#include <cstdio>\n#include \"llama.h\"\n\nint main() {\n",
+        "prompt": "    int n_threads = llama_",
+        "input_suffix": "}\n",
+    })
+    assert res.status_code == 200
+    assert res.body["content"] == "n_threads();\n    printf(\"Number of threads: %d\\n\", n_threads);\n    return 0;\n"
@@ -371,3 +371,6 @@ def match_regex(regex: str, text: str) -> bool:
         ).search(text)
         is not None
     )
+
+def is_slow_test_allowed():
+    return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON"
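The new Qwen2.5-Coder test downloads a model, so it is gated behind is_slow_test_allowed() and only runs when the SLOW_TESTS environment variable is "1" or "ON". A small sketch of enabling it programmatically before the suite starts; exporting SLOW_TESTS=1 in the shell is equivalent, and the exact pytest invocation depends on the local setup:

# Sketch: opt in to the slow, model-downloading test from a wrapper script or
# conftest; equivalent to exporting SLOW_TESTS=1 before running pytest.
import os

os.environ["SLOW_TESTS"] = "1"  # "ON" is also accepted by is_slow_test_allowed()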