mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-30)
scripts: benchmark for HTTP server throughput (#14668)

* scripts: benchmark for HTTP server throughput
* fix server connection reset
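The script is pointed at a llama-server binary and a model file; a typical invocation (binary and model paths are illustrative, the remaining flags fall back to the defaults listed in the script below) looks like:

    python scripts/server-bench.py --path_server ./build/bin/llama-server --path_model ./models/model.gguf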
@@ -3,6 +3,7 @@
 -r ../tools/server/tests/requirements.txt
 
 -r ./requirements-compare-llama-bench.txt
+-r ./requirements-server-bench.txt
 -r ./requirements-pydantic.txt
 -r ./requirements-test-tokenizer-random.txt
requirements/requirements-server-bench.txt (new file, 5 lines)
datasets~=3.2.0
matplotlib~=3.10.0
numpy~=1.26.4
requests~=2.32.3
tqdm~=4.67.1
scripts/server-bench.py (new file, 210 lines)
#!/usr/bin/env python3

import argparse
import json
import logging
import subprocess
from time import sleep, time
from typing import Optional

import datasets
import matplotlib.pyplot as plt
import numpy as np
import requests
from tqdm.contrib.concurrent import thread_map


logging.basicConfig(level=logging.INFO, format='%(message)s')
logger = logging.getLogger("server-bench")


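# Prompts are taken from the question field of the MMLU test split; n_prompts < 0 means "use all of them".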
def get_prompts(n_prompts: int) -> list[str]:
    logger.info("Loading MMLU dataset...")
    ret = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
    if n_prompts >= 0:
        ret = ret[:n_prompts]
    return ret


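# Starts llama-server as a subprocess and polls its /health endpoint until the server
# reports ready, giving up after 10 failed connection attempts (roughly 10 seconds).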
def get_server(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int) -> dict:
    logger.info("Starting the llama.cpp server...")
    address = f"http://localhost:{port}"

    popen_args: list[str] = [
        path_server,
        "--flash-attn",
        "--n-gpu-layers", str(n_gpu_layers),
        "--parallel", str(parallel),
        "--ctx-size", str(parallel * ctx_size),
        "--model", path_model,
        "--port", str(port),
        "--swa-full",  # FIXME performance bad otherwise
        # "--attn-streams",
    ]
    fout = open(path_log, "w") if path_log is not None else subprocess.DEVNULL
    process = subprocess.Popen(popen_args, stdout=fout, stderr=subprocess.STDOUT)

    n_failures: int = 0
    while True:
        try:
            sleep(1.0)
            exit_code = process.poll()
            if exit_code is not None:
                raise RuntimeError(f"llama.cpp server for {path_model} exited unexpectedly with exit code {exit_code}")
            response = requests.get(f"{address}/health")
            if response.status_code == 200:
                break
        except requests.ConnectionError:
            n_failures += 1
            if n_failures >= 10:
                raise RuntimeError(f"llama.cpp server for {path_model} is not healthy after 10 seconds")

    return {"process": process, "address": address, "fout": fout}


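# Renders a prompt through the model's chat template via /apply-template,
# then counts its tokens via /tokenize.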
def get_prompt_length(data: dict) -> int:
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    prompt: str = json.loads(response.text)["prompt"]
    response = session.post(
        f"{server_address}/tokenize",
        json={"content": prompt, "add_special": True}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    tokens: list[str] = json.loads(response.text)["tokens"]
    return len(tokens)


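# Sends one templated prompt to /completion with streaming enabled and records
# the wall-clock arrival time of every streamed token.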
def send_prompt(data: dict) -> tuple[float, list[float]]:
    session = data["session"]
    server_address: str = data["server_address"]

    response = session.post(
        f"{server_address}/apply-template",
        json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
    )
    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    prompt: str = json.loads(response.text)["prompt"]

    json_data: dict = {"prompt": prompt, "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
    response = session.post(f"{server_address}/completion", json=json_data, stream=True)

    last_valid_line: str = ""
    token_arrival_times: list[float] = []
    for line in response.iter_lines(decode_unicode=True):
        if not line.startswith("data: "):
            continue
        last_valid_line = line
        token_arrival_times.append(time())
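    # The final "data: " line carries the timing summary rather than a generated token, so drop its timestamp.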
    token_arrival_times = token_arrival_times[:-1]

    if response.status_code != 200:
        raise RuntimeError(f"Server returned status code {response.status_code}: {response.text}")
    timings: dict = json.loads(last_valid_line[6:])["timings"]

    return (timings["prompt_ms"], token_arrival_times)


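# End-to-end benchmark: start the server, measure all prompt lengths, fire the prompts
# at the server concurrently, then aggregate the timings into summary statistics and plots.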
def benchmark(path_server: str, path_model: str, path_log: Optional[str], port: int, n_gpu_layers: int, parallel: int, ctx_size: int, n_prompts: int, n_predict: int):
    num_workers: int = parallel + 1
    prompts: list[str] = get_prompts(n_prompts)

    server: Optional[dict] = None
    session = None
    try:
        server = get_server(path_server, path_model, path_log, port, n_gpu_layers, parallel, ctx_size)
        server_address: str = server["address"]

        adapter = requests.adapters.HTTPAdapter(pool_connections=num_workers, pool_maxsize=num_workers)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        data: list[dict] = []
        for i, p in enumerate(prompts):
            data.append({"session": session, "server_address": server_address, "prompt": p, "n_predict": n_predict, "seed": i})

        logger.info("Getting the prompt lengths...")
        prompt_n = [get_prompt_length(d) for d in data]

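        # One worker thread more than there are server slots, so that a request is already
        # queued whenever a slot frees up; chunksize=1 hands out prompts one at a time.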
        logger.info("Starting the benchmark...\n")
        t0 = time()
        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=num_workers, chunksize=1)
    finally:
        if server is not None:
            server["process"].terminate()
            server["process"].wait()
        if session is not None:
            session.close()

    prompt_ms = []
    token_t = []
    depth_sum: int = 0
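    # The "depth" of a generated token is its context position: prompt length plus the number of
    # tokens generated before it. Per request that sums to n_tokens * pn plus the series 1 + ... + n_tokens.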
    for pn, (pms, tat) in zip(prompt_n, results):
        prompt_ms.append(pms)
        token_t += tat
        n_tokens: int = len(tat)
        depth_sum += n_tokens * pn
        depth_sum += n_tokens * (n_tokens + 1) // 2
    prompt_n = np.array(prompt_n, dtype=np.int64)
    prompt_ms = np.array(prompt_ms, dtype=np.float64)
    token_t = np.array(token_t, dtype=np.float64)

    token_t -= t0
    token_t_last = np.max(token_t)

    logger.info("")
    logger.info(f"Benchmark duration:                {token_t_last:.2f} s")
    logger.info(f"Request throughput:                {n_prompts / token_t_last:.2f} requests/s = {n_prompts / (token_t_last/60):.2f} requests/min")
    logger.info(f"Total prompt length:               {np.sum(prompt_n)} tokens")
    logger.info(f"Average prompt length:             {np.mean(prompt_n):.2f} tokens")
    logger.info(f"Average prompt latency:            {np.mean(prompt_ms):.2f} ms")
    logger.info(f"Average prompt speed:              {np.sum(prompt_n) / (1e-3 * np.sum(prompt_ms)):.2f} tokens/s")
    logger.info(f"Total generated tokens:            {token_t.shape[0]}")
    logger.info(f"Average generation depth:          {depth_sum / token_t.shape[0]:.2f} tokens")
    logger.info(f"Average total generation speed:    {token_t.shape[0] / token_t_last:.2f} tokens/s")
    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")

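    # Two plots: time-to-first-token vs. prompt length (scatter), and the number of
    # tokens generated in each one-second bin over the run (histogram).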
    plt.figure()
    plt.scatter(prompt_n, prompt_ms, s=10.0, marker=".", alpha=0.25)
    plt.xlim(0, 1.05 * np.max(prompt_n))
    plt.ylim(0, 1.05 * np.max(prompt_ms))
    plt.title(path_model)
    plt.xlabel("Prompt length [tokens]")
    plt.ylabel("Time to first token [ms]")
    plt.savefig("prompt_time.png", dpi=240)

    bin_max = np.ceil(token_t_last) + 1
    plt.figure()
    plt.hist(token_t, np.arange(0, bin_max))
    plt.xlim(0, bin_max + 1)
    plt.title(path_model)
    plt.xlabel("Time [s]")
    plt.ylabel("Num. tokens generated per second")
    plt.savefig("gen_rate.png", dpi=240)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
        "Results are printed to console and visualized as plots (saved to current working directory).")
    parser.add_argument("--path_server", type=str, default="llama-server", help="Path to the llama.cpp server binary")
    parser.add_argument("--path_model", type=str, required=True, help="Path to the model to use for the benchmark")
    parser.add_argument("--path_log", type=str, default=None, help="Path to the file for writing the server log (default: no log)")
    parser.add_argument("--port", type=int, default=18725, help="Port to use for the server during the benchmark")
    parser.add_argument("--n_gpu_layers", type=int, default=999, help="Number of GPU layers for the server")
    parser.add_argument("--parallel", type=int, default=16, help="Number of slots for the server")
    parser.add_argument("--ctx_size", type=int, default=4096, help="Server context size per slot")
    parser.add_argument("--n_prompts", type=int, default=1000, help="Number of prompts to evaluate")
    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
    args = parser.parse_args()
    benchmark(**vars(args))
@@ -11,6 +11,8 @@
 
 // increase max payload length to allow use of larger context size
 #define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
+// increase backlog size to avoid connection resets for >> 1 slots
+#define CPPHTTPLIB_LISTEN_BACKLOG 512
 // disable Nagle's algorithm
 #define CPPHTTPLIB_TCP_NODELAY true
 #include <cpp-httplib/httplib.h>
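For context on the connection-reset fix: cpp-httplib's CPPHTTPLIB_LISTEN_BACKLOG sets the backlog argument of the underlying listen() call, which bounds the kernel's queue of connections that have completed the TCP handshake but have not yet been accepted by the application. When many benchmark workers connect at once, a small backlog overflows and clients see resets. A minimal Python sketch of the same knob (illustrative only; the port is just the benchmark script's default):

    import socket

    srv = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    srv.bind(("localhost", 18725))
    srv.listen(512)  # backlog: pending, not-yet-accepted connections the OS will queue
    srv.close()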
Johannes Gäßler