mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-29 08:41:22 +00:00 
			
		
		
		
	 e81b8e4b7f
			
		
	
	e81b8e4b7f
	
	
	
		
			
			* llama: use max. GPU layers by default, auto -fa * ggml-backend: abort instead of segfault
		
			
				
	
	
		
			298 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			298 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
		
			Executable File
		
	
	
	
	
| #!/usr/bin/env python3
 | |
| 
 | |
| import argparse
 | |
| import json
 | |
| import os
 | |
| import random
 | |
| import sqlite3
 | |
| import subprocess
 | |
| from time import sleep, time
 | |
| from typing import Optional, Union
 | |
| 
 | |
| import datasets
 | |
| import logging
 | |
| import matplotlib.pyplot as plt
 | |
| import numpy as np
 | |
| import requests
 | |
| from tqdm.contrib.concurrent import thread_map
 | |
| 
 | |
| 
 | |
# Emit bare messages (no level/logger prefix) at INFO level so the benchmark
# summary printed through `logger` reads cleanly on the console.
logging.basicConfig(format="%(message)s", level=logging.INFO)
logger = logging.getLogger("server-bench")
 | |
| 
 | |
| 
 | |
def get_prompts_text(dataset_name: str, n_prompts: int) -> Optional[list[str]]:
    """Load natural-language prompts from a named dataset.

    Only "mmlu" (case-insensitive) is recognized: it yields the "question"
    column of the test split of the "cais/mmlu" dataset (config "all").

    Args:
        dataset_name: Dataset identifier, e.g. "mmlu".
        n_prompts: Keep only the first n_prompts questions; a negative value
            keeps all of them.

    Returns:
        The questions, or None if dataset_name is not a recognized dataset.
    """
    if dataset_name.lower() != "mmlu":
        return None

    logger.info("Loading MMLU dataset...")
    questions = datasets.load_dataset("cais/mmlu", "all")["test"]["question"]  # type: ignore
    return questions if n_prompts < 0 else questions[:n_prompts]
 | |
| 
 | |
| 
 | |
def get_prompt_lengths_rng(n_prompts: int, prompt_length_min: int, prompt_length_max: int, seed_offset: int) -> list[int]:
    """Draw one random prompt length in [prompt_length_min, prompt_length_max] per prompt.

    When seed_offset is non-negative, the global `random` generator is reseeded
    before the draw for prompt i with 3 * (seed_offset + 1000 * i). This makes
    each length reproducible independently of the others. A negative
    seed_offset draws from the generator's current state instead.

    Args:
        n_prompts: Number of lengths to draw (must be >= 0).
        prompt_length_min: Inclusive lower bound of a prompt length.
        prompt_length_max: Inclusive upper bound of a prompt length.
        seed_offset: Base for the per-prompt seeds; negative disables reseeding.

    Returns:
        A list of n_prompts integer lengths.
    """
    assert n_prompts >= 0
    reseed: bool = seed_offset >= 0

    def draw(index: int) -> int:
        if reseed:
            # The "+0" slot of the 3-way seed scheme; other draws in this script use +1 and +2.
            random.seed(3 * (seed_offset + 1000 * index))
        return random.randint(prompt_length_min, prompt_length_max)

    return [draw(index) for index in range(n_prompts)]
 | |
| 
 | |
| 
 | |
def get_prompts_rng(prompt_lengths: list[int]) -> list[list[int]]:
    """Build synthetic prompts made of random token ids.

    Prompt i consists of prompt_lengths[i] token ids, each drawn uniformly from
    [100, 10000] using the global `random` generator in its current state (no
    reseeding happens here).
    """
    prompts: list[list[int]] = []
    for length in prompt_lengths:
        prompts.append([random.randint(100, 10000) for _ in range(length)])
    return prompts
 | |
| 
 | |
| 
 | |
def get_server(path_server: str, path_log: Optional[str]) -> dict:
    """Connect to, or launch, the llama.cpp server under test.

    External server: if path_server is an http(s) URL, it is used as the
    address of an already running server and nothing is started.

    Local server: otherwise path_server is executed without arguments. It is
    configured through LLAMA_ARG_* environment variables. LLAMA_ARG_HOST and
    LLAMA_ARG_PORT default to 127.0.0.1 and 8080 and are written back to
    os.environ when unset. The server is then polled via GET /health once per
    second until it answers with status 200.

    Args:
        path_server: Server binary to run, or the base URL of an external server.
        path_log: File receiving the server's stdout/stderr. "{port}" is
            replaced with the port. None discards the output.

    Returns:
        A dict with the following keys:
            "process": the Popen handle, or None for an external server.
            "address": the base URL.
            "fout": the opened log file, subprocess.DEVNULL, or None for an
                external server.
        The caller is responsible for terminating the process and closing fout.

    Raises:
        RuntimeError: if the server process exits during startup, or if
            connecting to it fails 10 times. In that case, and on any other
            exception or interrupt during startup, the spawned process is
            terminated and the log file closed before the exception propagates.
    """
    if path_server.startswith(("http://", "https://")):
        return {"process": None, "address": path_server, "fout": None}
    if os.environ.get("LLAMA_ARG_HOST") is None:
        logger.info("LLAMA_ARG_HOST not explicitly set, using 127.0.0.1")
        os.environ["LLAMA_ARG_HOST"] = "127.0.0.1"
    if os.environ.get("LLAMA_ARG_PORT") is None:
        logger.info("LLAMA_ARG_PORT not explicitly set, using 8080")
        os.environ["LLAMA_ARG_PORT"] = "8080"
    hostname: Optional[str] = os.environ.get("LLAMA_ARG_HOST")
    port: Optional[str] = os.environ.get("LLAMA_ARG_PORT")
    assert hostname is not None
    assert port is not None
    address: str = f"http://{hostname}:{port}"
    logger.info(f"Starting the llama.cpp server under {address}...")

    fout = open(path_log.format(port=port), "w") if path_log is not None else subprocess.DEVNULL
    process: Optional[subprocess.Popen] = None
    try:
        process = subprocess.Popen([path_server], stdout=fout, stderr=subprocess.STDOUT)

        n_failures: int = 0
        while True:
            try:
                sleep(1.0)
                exit_code = process.poll()
                if exit_code is not None:
                    raise RuntimeError(f"llama.cpp server exited unexpectedly with exit code {exit_code}{path_log and f', see {path_log.format(port=port)}' or ''}")
                response = requests.get(f"{address}/health")
                # Non-200 answers (e.g. while the model is still loading) keep waiting
                # without counting as failures; only refused connections are counted.
                if response.status_code == 200:
                    break
            except requests.ConnectionError:
                n_failures += 1
                if n_failures >= 10:
                    raise RuntimeError("llama.cpp server is not healthy after 10 seconds")
    except BaseException:
        # The caller never receives the handles when startup fails, so clean up here:
        # otherwise the server keeps running in the background and the log file stays open.
        if process is not None:
            process.terminate()
            process.wait()
        if fout is not subprocess.DEVNULL:
            fout.close()
        raise

    return {"process": process, "address": address, "fout": fout}
 | |
| 
 | |
| 
 | |
def get_prompt_length(data: dict) -> int:
    """Return the token count of a prompt after the server applies its chat template.

    The prompt is first wrapped as a single user message and rendered through
    the server's /apply-template endpoint. The rendered text is then tokenized
    via /tokenize with special tokens added.

    Args:
        data: Request description. Reads "session" (HTTP session used for the
            POSTs), "server_address" (base URL) and "prompt" (message text).

    Returns:
        Length of the "tokens" list returned by /tokenize.

    Raises:
        requests.HTTPError: if either request returns an error status
            (via raise_for_status).
    """
    http = data["session"]
    base_url: str = data["server_address"]

    def post_json(endpoint: str, payload: dict) -> dict:
        reply = http.post(f"{base_url}{endpoint}", json=payload)
        reply.raise_for_status()
        return json.loads(reply.text)

    # NOTE(review): "stream" sits inside the message object rather than the request body; kept as sent originally.
    chat_message = {"role": "user", "content": data["prompt"], "stream": True}
    templated: str = post_json("/apply-template", {"messages": [chat_message]})["prompt"]
    tokens: list = post_json("/tokenize", {"content": templated, "add_special": True})["tokens"]
    return len(tokens)
 | |
| 
 | |
| 
 | |
def send_prompt(data: dict) -> tuple[float, list[float]]:
    """Send one streaming generation request and timestamp each generated token.

    The request style depends on the prompt type:
      - External server: POST /v1/completions (OpenAI-compatible, "max_tokens").
      - Synthetic prompt: POST /completion with ignore_eos and cache_prompt
        disabled.
      - Otherwise: the prompt is first rendered through /apply-template, then
        sent to /completion.

    A time.time() timestamp is recorded for every "data: " line of the stream.
    The last such event is then discarded (presumably the end-of-stream or
    final summary event). The one before it is also discarded when it contains
    a "timings" object, so that only token-carrying events remain.

    Args:
        data: Request description. Reads "session", "server_address",
            "external_server", "synthetic_prompt", "prompt", "seed" and
            "n_predict".

    Returns:
        (t_submit, token_arrival_times): the submission time and the arrival
        time of each generated token, all from time.time().

    Raises:
        requests.HTTPError: if any request returns an error status.
    """
    session = data["session"]
    base_url: str = data["server_address"]

    t_submit = time()
    if data["external_server"]:
        endpoint = "/v1/completions"
        payload: dict = {
            "prompt": data["prompt"], "ignore_eos": True,
            "seed": data["seed"], "max_tokens": data["n_predict"], "stream": True}
    elif data["synthetic_prompt"]:
        endpoint = "/completion"
        payload = {
            "prompt": data["prompt"], "ignore_eos": True, "cache_prompt": False,
            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}
    else:
        templated = session.post(
            f"{base_url}/apply-template",
            json={"messages": [{"role": "user", "content": data["prompt"], "stream": True}]}
        )
        templated.raise_for_status()
        endpoint = "/completion"
        payload = {
            "prompt": json.loads(templated.text)["prompt"],
            "seed": data["seed"], "n_predict": data["n_predict"], "stream": True}

    response = session.post(f"{base_url}{endpoint}", json=payload, stream=True)
    response.raise_for_status()

    events: list = []
    arrivals: list[float] = []
    for raw in response.iter_lines(decode_unicode=False):
        if raw.startswith(b"data: "):
            events.append(raw)
            arrivals.append(time())

    # Drop the trailing non-token event(s): always the last one, plus the
    # second-to-last when it carries the server's "timings" summary.
    arrivals = arrivals[:-1]
    if len(events) > 1 and "timings" in json.loads(events[-2][6:]):
        arrivals = arrivals[:-1]

    return t_submit, arrivals
 | |
| 
 | |
| 
 | |
def _save_result_to_db(path_db: str, row: list) -> None:
    """Append one result row to the server_bench table of the SQLite database at path_db.

    The table is created if it does not exist. The connection is always closed,
    even when a statement fails.
    """
    con = sqlite3.connect(path_db)
    try:
        cursor = con.cursor()
        cursor.execute(
            "CREATE TABLE IF NOT EXISTS server_bench"
            "(name TEXT, n_parallel INTEGER, prompt_source TEXT, n_prompts INTEGER, "
            "n_predict INTEGER, n_predict_min INTEGER, seed_offset INTEGER, runtime REAL);")
        cursor.execute(
            "INSERT INTO server_bench VALUES (?, ?, ?, ?, ?, ?, ?, ?);",
            row)
        con.commit()
    finally:
        con.close()


def _plot_results(name: Optional[str], prompt_n: np.ndarray, prompt_t: np.ndarray, token_t: np.ndarray, token_t_last: float) -> None:
    """Save two plots to the current working directory.

    - prompt_time.png: time to first token vs. prompt length.
    - gen_rate.png: histogram of generated tokens per 1-second bin.
    """
    plt.figure()
    plt.scatter(prompt_n, 1e3 * prompt_t, s=10.0, marker=".", alpha=0.25)
    plt.xlim(0, 1.05e0 * np.max(prompt_n))
    plt.ylim(0, 1.05e3 * np.max(prompt_t))
    plt.title(name or "")
    plt.xlabel("Prompt length [tokens]")
    plt.ylabel("Time to first token [ms]")
    plt.savefig("prompt_time.png", dpi=240)

    bin_max = np.ceil(token_t_last) + 1
    plt.figure()
    plt.hist(token_t, np.arange(0, bin_max))
    plt.xlim(0, bin_max + 1)
    plt.title(name or "")
    plt.xlabel("Time [s]")
    plt.ylabel("Num. tokens generated per second")
    plt.savefig("gen_rate.png", dpi=240)


def benchmark(
        path_server: str, path_log: Optional[str], path_db: Optional[str], name: Optional[str], prompt_source: str, n_prompts: int,
        n_predict: int, n_predict_min: int, seed_offset: int):
    """Benchmark the throughput of a llama.cpp server and report the results.

    Prompt sources:
      - "mmlu": MMLU questions.
      - "rng-MIN-MAX": synthetic prompts of random token ids, with lengths
        drawn from [MIN, MAX].

    Server: unless path_server is an http(s) URL of a running server, a local
    llama-server is started (configured through LLAMA_ARG_* environment
    variables) and terminated afterwards.

    Concurrency: prompts are sent concurrently, with at most
    LLAMA_ARG_N_PARALLEL requests in flight (default 32). Summary statistics
    are logged, optionally appended to an SQLite database, and plotted to
    prompt_time.png and gen_rate.png in the current working directory.

    Args:
        path_server: llama-server binary, or base URL of an external server.
        path_log: Log file for a locally started server; "{port}" is replaced
            with the port. None discards the server output.
        path_db: Optional SQLite database to append the result row to.
        name: Label for plots and the database row.
        prompt_source: "mmlu" or "rng-MIN-MAX".
        n_prompts: Number of prompts. For "mmlu", a negative value uses all
            questions.
        n_predict: Max. number of tokens to generate per prompt.
        n_predict_min: Min. number of tokens to generate per prompt. Only used
            for synthetic prompts; for dataset prompts it is replaced with
            n_predict.
        seed_offset: Base for the per-prompt seeds; negative disables seeding.
    """
    external_server: bool = path_server.startswith(("http://", "https://"))
    if os.environ.get("LLAMA_ARG_N_PARALLEL") is None:
        logger.info("LLAMA_ARG_N_PARALLEL not explicitly set, using 32")
        os.environ["LLAMA_ARG_N_PARALLEL"] = "32"

    parallel: int = int(os.environ["LLAMA_ARG_N_PARALLEL"])
    prompts: Union[None, list[str], list[list[int]]] = get_prompts_text(prompt_source, n_prompts)
    synthetic_prompts: bool = prompts is None
    prompt_n = []

    if synthetic_prompts:
        prompt_source_split: list[str] = prompt_source.split("-")
        assert len(prompt_source_split) == 3
        assert prompt_source_split[0].lower() == "rng"
        prompt_length_min: int = int(prompt_source_split[1])
        prompt_length_max: int = int(prompt_source_split[2])
        logger.info("Generating random prompts...")
        prompt_n = get_prompt_lengths_rng(n_prompts, prompt_length_min, prompt_length_max, seed_offset)
        prompts = get_prompts_rng(prompt_n)
    else:
        # Dataset prompts all get the same n_predict (the draw below becomes randint(n, n)).
        n_predict_min = n_predict

    if not external_server and os.environ.get("LLAMA_ARG_CTX_SIZE") is None:
        # Size each slot for prompt + generation with 5% headroom. Dataset prompt lengths are
        # only measured once the server is running, so 2048 tokens are assumed for them here.
        context_per_slot: int = int(1.05 * (n_predict + (np.max(prompt_n) if synthetic_prompts else 2048)))
        context_total: int = context_per_slot * parallel
        os.environ["LLAMA_ARG_CTX_SIZE"] = str(context_total)
        logger.info(f"LLAMA_ARG_CTX_SIZE not explicitly set, using {context_total} ({context_per_slot} per slot).")

    server: Optional[dict] = None
    session = None
    try:
        server = get_server(path_server, path_log)
        server_address: str = server["address"]
        assert external_server == (server["process"] is None)

        adapter = requests.adapters.HTTPAdapter(pool_connections=parallel, pool_maxsize=parallel)  # type: ignore
        session = requests.Session()
        session.mount("http://", adapter)
        session.mount("https://", adapter)

        data: list[dict] = []

        for i, p in enumerate(prompts):
            if seed_offset >= 0:
                random.seed(3 * (seed_offset + 1000 * i) + 1)
            data.append({
                "session": session, "server_address": server_address, "external_server": external_server, "prompt": p,
                "synthetic_prompt": synthetic_prompts, "n_predict": random.randint(n_predict_min, n_predict),
                "seed": (3 * (seed_offset + 1000 * i) + 2) if seed_offset >= 0 else -1})

        if not synthetic_prompts:
            logger.info("Getting the prompt lengths...")
            prompt_n = [get_prompt_length(d) for d in data]

        logger.info("Starting the benchmark...\n")
        t0 = time()
        results: list[tuple[float, list[float]]] = thread_map(send_prompt, data, max_workers=parallel, chunksize=1)
    finally:
        if server is not None and server["process"] is not None:
            server["process"].terminate()
            server["process"].wait()
        # get_server hands over the open log file; close it once the server is gone.
        if server is not None:
            fout = server["fout"]
            if fout is not None and fout is not subprocess.DEVNULL:
                fout.close()
        if session is not None:
            session.close()

    # Use the number of requests actually sent: n_prompts may be negative ("all" for
    # datasets) or larger than the dataset, which made the throughput wrong or negative.
    n_requests: int = len(results)

    prompt_t = []
    token_t = []
    depth_sum: int = 0
    for pn, (t_submit, tat) in zip(prompt_n, results):
        # Time to first token, measured from request submission.
        prompt_t.append(tat[0] - t_submit)
        token_t += tat
        n_tokens: int = len(tat)
        # Generated token k (1-based) is counted at depth pn + k; summed over k = 1..n_tokens.
        depth_sum += n_tokens * pn
        depth_sum += n_tokens * (n_tokens + 1) // 2
    assert len(token_t) > 0
    prompt_n = np.array(prompt_n, dtype=np.int64)
    prompt_t = np.array(prompt_t, dtype=np.float64)
    token_t = np.array(token_t, dtype=np.float64)

    token_t -= t0
    token_t_last = np.max(token_t)

    logger.info("")
    logger.info(f"Benchmark duration:                {token_t_last:.2f} s")
    logger.info(f"Request throughput:                {n_requests / token_t_last:.2f} requests/s = {n_requests / (token_t_last/60):.2f} requests/min")
    logger.info(f"Total prompt length:               {np.sum(prompt_n)} tokens")
    logger.info(f"Average prompt length:             {np.mean(prompt_n):.2f} tokens")
    logger.info(f"Average prompt latency:            {1e3 * np.mean(prompt_t):.2f} ms")
    logger.info(f"Average prompt speed:              {np.sum(prompt_n) / np.sum(prompt_t):.2f} tokens/s")
    logger.info(f"Total generated tokens:            {token_t.shape[0]}")
    logger.info(f"Average generation depth:          {depth_sum / token_t.shape[0]:.2f} tokens")
    logger.info(f"Average total generation speed:    {token_t.shape[0] / token_t_last:.2f} tokens/s")
    logger.info(f"Average generation speed per slot: {token_t.shape[0] / (parallel * token_t_last):.2f} tokens/s / slot")

    if path_db is not None:
        _save_result_to_db(
            path_db,
            [name, parallel, prompt_source, n_prompts, n_predict, n_predict_min, seed_offset, token_t_last])

    _plot_results(name, prompt_n, prompt_t, token_t, token_t_last)
 | |
| 
 | |
| 
 | |
if __name__ == "__main__":
    # Command-line entry point: every option maps 1:1 onto a benchmark() parameter.
    parser = argparse.ArgumentParser(
        description="Tool for benchmarking the throughput of the llama.cpp HTTP server. "
        "Results are printed to console and visualized as plots (saved to current working directory). "
        "To pass arguments such as the model path to the server, set the corresponding environment variables (see llama-server --help). "
        "The reported numbers are the speeds as observed by the Python script and may differ from the performance reported by the server, "
        "particularly when the server is fast vs. the network or Python script (e.g. when serving a very small model).")
    parser.add_argument(
        "--path_server", type=str, default="llama-server",
        help="Path to the llama.cpp server binary, or the http(s) URL of an already running server")
    parser.add_argument(
        "--path_log", type=str, default="server-bench-{port}.log",
        help="Path to the log file for the server output, '{port}' is replaced with the server port "
        "(not used when --path_server is a URL)")
    parser.add_argument("--path_db", type=str, default=None, help="Path to an sqlite database to store the benchmark results in")
    parser.add_argument("--name", type=str, default=None, help="Name to label plots and database entries with")
    parser.add_argument(
        "--prompt_source", type=str, default="rng-1024-2048",
        help="How to get the prompts for the benchmark, either 'mmlu' for MMLU questions or "
        "rng-MIN-MAX for synthetic prompts with random lengths in the interval [MIN, MAX]")
    parser.add_argument("--n_prompts", type=int, default=100, help="Number of prompts to evaluate")
    parser.add_argument("--n_predict", type=int, default=2048, help="Max. number of tokens to predict per prompt")
    parser.add_argument(
        "--n_predict_min", type=int, default=1024,
        help="Min. number of tokens to predict per prompt (supported for synthetic prompts only)")
    parser.add_argument("--seed_offset", type=int, default=0, help="Offset for determining the seeds for pseudorandom prompt/generation lengths. "
                        "Correlations between seeds can occur when set >= 1000. Negative values mean no seed.")
    args = parser.parse_args()
    benchmark(**vars(args))
 |