Mirror of https://github.com/ggml-org/llama.cpp.git
	server : (refactor) no more json in server_task input (#10691)

* server : (refactor) no more json in server_task input
* add test for slots endpoint
* add tests for /props and /slots
* remove task inf_type
* fix CI by adding safe_json_to_str
* add "model_path" to /props
* update readme
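For context, a minimal client-side sketch (not part of the commit; it assumes a llama-server instance already running at http://localhost:8080) of what the updated `GET /props` response exposes, including the new `model_path` field:

```python
import requests  # assumption: third-party HTTP client, not the test helper used below

# Assumed base URL of a locally running llama-server instance.
BASE_URL = "http://localhost:8080"

# Fetch the global server properties.
res = requests.get(f"{BASE_URL}/props")
res.raise_for_status()
props = res.json()

# Fields documented for GET /props; `model_path` is the one added by this commit.
print("model path:   ", props["model_path"])
print("total slots:  ", props["total_slots"])
print("default n_ctx:", props["default_generation_settings"]["n_ctx"])
print("chat template length:", len(props["chat_template"]))
```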
````diff
@@ -687,12 +687,14 @@ This endpoint is public (no API key check). By default, it is read-only. To make
     }
   },
   "total_slots": 1,
+  "model_path": "../models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf",
   "chat_template": "..."
 }
 ```
 
 - `default_generation_settings` - the default generation settings for the `/completion` endpoint, which has the same fields as the `generation_settings` response object from the `/completion` endpoint.
 - `total_slots` - the total number of slots for process requests (defined by `--parallel` option)
+- `model_path` - the path to the model file (same as the `-m` argument)
 - `chat_template` - the model's original Jinja2 prompt template
 
 ### POST `/props`: Change server global properties.
````
[File diff suppressed because it is too large]
```diff
@@ -22,7 +22,12 @@ def test_server_props():
     server.start()
     res = server.make_request("GET", "/props")
     assert res.status_code == 200
+    assert ".gguf" in res.body["model_path"]
     assert res.body["total_slots"] == server.n_slots
+    default_val = res.body["default_generation_settings"]
+    assert server.n_ctx is not None and server.n_slots is not None
+    assert default_val["n_ctx"] == server.n_ctx / server.n_slots
+    assert default_val["params"]["seed"] == server.seed
 
 
 def test_server_models():
@@ -33,6 +38,31 @@ def test_server_models():
     assert len(res.body["data"]) == 1
     assert res.body["data"][0]["id"] == server.model_alias
 
+
+def test_server_slots():
+    global server
+
+    # without slots endpoint enabled, this should return error
+    server.server_slots = False
+    server.start()
+    res = server.make_request("GET", "/slots")
+    assert res.status_code == 501 # ERROR_TYPE_NOT_SUPPORTED
+    assert "error" in res.body
+    server.stop()
+
+    # with slots endpoint enabled, this should return slots info
+    server.server_slots = True
+    server.n_slots = 2
+    server.start()
+    res = server.make_request("GET", "/slots")
+    assert res.status_code == 200
+    assert len(res.body) == server.n_slots
+    assert server.n_ctx is not None and server.n_slots is not None
+    assert res.body[0]["n_ctx"] == server.n_ctx / server.n_slots
+    assert "params" in res.body[0]
+    assert res.body[0]["params"]["seed"] == server.seed
+
+
 def test_load_split_model():
     global server
     server.model_hf_repo = "ggml-org/models"
```
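Complementing the test above, a hedged sketch of how an external client might use the `/slots` endpoint once the server is started with `--slots` (the base URL and the error handling beyond what `test_server_slots` asserts are assumptions):

```python
import requests  # assumption: plain HTTP client instead of the ServerProcess test helper

BASE_URL = "http://localhost:8080"  # assumed address of a running llama-server

res = requests.get(f"{BASE_URL}/slots")
if res.status_code == 501:
    # Matches the behaviour exercised above: without --slots the endpoint is
    # reported as not supported and the body carries an "error" object.
    print("slots endpoint disabled:", res.json()["error"])
else:
    slots = res.json()  # a list with one entry per slot (--parallel)
    print("number of slots:", len(slots))
    for slot in slots:
        # n_ctx and params.seed are the per-slot fields checked by the new test.
        print("slot n_ctx:", slot["n_ctx"], "seed:", slot["params"]["seed"])
```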
```diff
@@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
         ],
     })
     assert res.status_code == 200
+    assert "cmpl" in res.body["id"] # make sure the completion id has the expected format
     assert res.body["model"] == model if model is not None else server.model_alias
     assert res.body["usage"]["prompt_tokens"] == n_prompt
     assert res.body["usage"]["completion_tokens"] == n_predicted
@@ -59,9 +60,13 @@ def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_conte
         "stream": True,
     })
     content = ""
+    last_cmpl_id = None
     for data in res:
         choice = data["choices"][0]
         assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
+        if last_cmpl_id is None:
+            last_cmpl_id = data["id"]
+        assert last_cmpl_id == data["id"] # make sure the completion id is the same for all events in the stream
         if choice["finish_reason"] in ["stop", "length"]:
             assert data["usage"]["prompt_tokens"] == n_prompt
             assert data["usage"]["completion_tokens"] == n_predicted
```
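As a usage-level counterpart to the streaming test above, a hedged sketch of a raw client performing the same completion-id checks over SSE (the `/v1/chat/completions` route, the request fields, and the `[DONE]` sentinel handling are assumptions about the OpenAI-compatible API, not part of this diff):

```python
import json
import requests  # assumption: plain HTTP client instead of the test helper

BASE_URL = "http://localhost:8080"  # assumed address of a running llama-server

# Stream a chat completion and check that every SSE event carries the same id,
# mirroring the new assertions in test_chat_completion_stream.
with requests.post(
    f"{BASE_URL}/v1/chat/completions",  # assumed OpenAI-compatible route
    json={
        "messages": [{"role": "user", "content": "Say hello."}],
        "max_tokens": 16,
        "stream": True,
    },
    stream=True,
) as res:
    last_cmpl_id = None
    for line in res.iter_lines():
        if not line or not line.startswith(b"data: ") or line == b"data: [DONE]":
            continue
        data = json.loads(line[len(b"data: "):])
        assert "cmpl" in data["id"]        # completion id format checked above
        if last_cmpl_id is None:
            last_cmpl_id = data["id"]
        assert last_cmpl_id == data["id"]  # id stays constant across all events
```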
```diff
@@ -64,6 +64,7 @@ class ServerProcess:
     server_embeddings: bool | None = False
     server_reranking: bool | None = False
     server_metrics: bool | None = False
+    server_slots: bool | None = False
     draft: int | None = None
     api_key: str | None = None
     response_format: str | None = None
@@ -91,7 +92,6 @@ class ServerProcess:
         else:
             server_path = "../../../build/bin/llama-server"
         server_args = [
-            "--slots",  # requires to get slot status via /slots endpoint
             "--host",
             self.server_host,
             "--port",
@@ -129,6 +129,8 @@ class ServerProcess:
             server_args.append("--reranking")
         if self.server_metrics:
             server_args.append("--metrics")
+        if self.server_slots:
+            server_args.append("--slots")
         if self.model_alias:
             server_args.extend(["--alias", self.model_alias])
         if self.n_ctx:
@@ -181,7 +183,7 @@ class ServerProcess:
         start_time = time.time()
         while time.time() - start_time < timeout_seconds:
             try:
-                response = self.make_request("GET", "/slots", headers={
+                response = self.make_request("GET", "/health", headers={
                     "Authorization": f"Bearer {self.api_key}" if self.api_key else None
                 })
                 if response.status_code == 200:
@@ -224,7 +226,7 @@ class ServerProcess:
         result.headers = dict(response.headers)
         result.status_code = response.status_code
         result.body = response.json() if parse_body else None
-        print("Response from server", result.body)
+        print("Response from server", json.dumps(result.body, indent=2))
         return result
 
     def make_stream_request(
@@ -245,7 +247,7 @@ class ServerProcess:
                 break
             elif line.startswith('data: '):
                 data = json.loads(line[6:])
-                print("Partial response from server", data)
+                print("Partial response from server", json.dumps(data, indent=2))
                 yield data
 
 
```
```diff
@@ -164,6 +164,9 @@ static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, con
     } else {
         throw std::runtime_error("\"prompt\" must be a string, an list of tokens, a list of mixed strings & tokens, or a list of prompts");
     }
+    if (result.empty()) {
+        throw std::runtime_error("\"prompt\" must not be empty");
+    }
     return result;
 }
 
```
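The empty-prompt check added above surfaces to clients as an error response rather than an empty generation. A small hedged sketch of how that might look from the outside (the endpoint path is the `/completion` route mentioned in the README; the exact status code is not asserted because only the rejection itself is guaranteed by the new check):

```python
import requests  # assumption: plain HTTP client

BASE_URL = "http://localhost:8080"  # assumed address of a running llama-server

# An empty prompt list should now be rejected by tokenize_input_prompts.
res = requests.post(f"{BASE_URL}/completion", json={"prompt": []})

# The exact status code is not pinned down here; the new check only ensures
# the request fails with an error body instead of being silently accepted.
assert res.status_code != 200
print(res.json().get("error"))
```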
```diff
@@ -496,8 +499,6 @@ static json oaicompat_completion_params_parse(
     const std::string & chat_template) {
     json llama_params;
 
-    llama_params["__oaicompat"] = true;
-
     // Apply chat template to the list of messages
     llama_params["prompt"] = format_chat(model, chat_template, body.at("messages"));
 
@@ -648,3 +649,18 @@ static json format_detokenized_response(const std::string & content) {
         {"content", content}
     };
 }
+
+static json format_logit_bias(const std::vector<llama_logit_bias> & logit_bias) {
+    json data = json::array();
+    for (const auto & lb : logit_bias) {
+        data.push_back(json{
+            {"bias", lb.bias},
+            {"token", lb.token},
+        });
+    }
+    return data;
+}
+
+static std::string safe_json_to_str(json data) {
+    return data.dump(-1, ' ', false, json::error_handler_t::replace);
+}
```