#!/usr/bin/env uv run
'''
    Simplistic tool call benchmarks for llama-server and ollama.

    Essentially runs the tests at ./tools/server/tests/unit/test_tool_call.py N times, at different temperatures
    and on different backends (current llama-server, baseline llama-server and ollama),
    and plots the results of multiple runs (from the same .jsonl file or multiple ones) as a success rate heatmap.

    Simple usage example:

        cmake -B build -DLLAMA_CURL=1 && cmake --build build --config Release -j -t llama-server

        export LLAMA_SERVER_BIN_PATH=$PWD/build/bin/llama-server
        export LLAMA_CACHE=${LLAMA_CACHE:-$HOME/Library/Caches/llama.cpp}

        ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 1.5B Q4_K_M"      --output qwen1.5b.jsonl  --hf bartowski/Qwen2.5-1.5B-Instruct-GGUF      --ollama qwen2.5:1.5b-instruct-q4_K_M
        ./scripts/tool_bench.py run --n 30 --temp -1 --temp 0 --temp 1 --model "Qwen 2.5 Coder 7B Q4_K_M"  --output qwenc7b.jsonl   --hf bartowski/Qwen2.5-Coder-7B-Instruct-GGUF  --ollama qwen2.5-coder:7b

        ./scripts/tool_bench.py plot *.jsonl                         # Opens window w/ heatmap
        ./scripts/tool_bench.py plot qwen*.jsonl  --output qwen.png  # Saves heatmap to qwen.png

    (please see ./scripts/tool_bench.sh for a more complete example)
'''
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "pytest",
#     "pandas",
#     "numpy",
#     "matplotlib",
#     "seaborn",
#     "requests",
#     "wget",
#     "typer",
# ]
# ///
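# The block above is inline script metadata (PEP 723): `uv run` (see the
# shebang) reads it and provisions the listed dependencies automatically.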
from contextlib import contextmanager
from pathlib import Path
import re
from statistics import mean, median
from typing import Annotated, Dict, List, Optional, Tuple
import atexit
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import subprocess
import sys
import time
import typer
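
# Make the repo root importable so the server test harness under
# tools/server/tests can be reused; the `if True:` below keeps formatters from
# hoisting these imports above the sys.path manipulation.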
sys.path.insert(0, Path(__file__).parent.parent.as_posix())
if True:
    from tools.server.tests.utils import ServerProcess
    from tools.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather


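# Context manager that guarantees the wrapped server gets stopped: both on
# normal exit from the `with` block and, via atexit, if the process is torn
# down before the block completes.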
@contextmanager
def scoped_server(sp: ServerProcess):
    def stop():
        nonlocal sp
        if sp is not None:
            sp.stop()
            sp = None  # type: ignore
    atexit.register(stop)
    yield sp
    stop()


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = typer.Typer()


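# `plot` consumes the .jsonl files produced by `run`; each line is one JSON
# record, e.g. (illustrative values):
#   {"model": "...", "temp": 0.0, "server_name": "llama-server", "test": "weather",
#    "success_ratio": 0.9, "success_count": 27, "failure_count": 3, ...}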
@app.command()
def plot(files: List[Path], output: Optional[Path] = None, test_regex: Optional[str] = None, server_regex: Optional[str] = None):

    lines: List[Dict] = []
    for file in files:
        if not file.exists():
            logger.error(f"File not found: {file}")
            continue

        try:
            with file.open() as f:
                raw_data = f.read()
            logger.info(f"Reading {file} ({len(raw_data)} bytes)")

            for line_num, line in enumerate(raw_data.split('\n'), 1):
                line = line.strip()
                if not line:
                    continue
                try:
                    record = json.loads(line)
                    lines.append(record)
                except json.JSONDecodeError as e:
                    logger.warning(f"Invalid JSON at {file}:{line_num} - {e}")
        except Exception as e:
            logger.error(f"Error processing {file}: {e}")

    if not lines:
        raise Exception("No valid data was loaded")

    data_dict: Dict[Tuple, float] = {}
    models: List[str] = []
    temps = set()
    tests = set()
    server_names = set()
    total_counts = set()
    for rec in lines:
        try:
            model = rec["model"]
            temp = rec["temp"]
            server_name = rec["server_name"]
            test = rec["test"]
            success = rec["success_ratio"]
            success_count = rec["success_count"]
            failure_count = rec["failure_count"]
            total_count = success_count + failure_count
            total_counts.add(total_count)

            if test_regex and not re.search(test_regex, test):
                continue

            if server_regex and not re.search(server_regex, server_name):
                continue

            data_dict[(model, temp, server_name, test)] = success

            if model not in models:
                models.append(model)
            temps.add(temp)
            tests.add(test)
            server_names.add(server_name)

        except KeyError as e:
            logger.warning(f"Missing required field in record: {e}")

    if len(total_counts) > 1:
        logger.warning(f"Total counts are not consistent: {total_counts}")

    # Sort the collected values (None, i.e. server-default temperature, sorts first)
    temps = list(sorted(temps, key=lambda x: x if x is not None else -1))
    tests = list(sorted(tests))
    server_names = list(sorted(server_names))

    logger.info(f"Processed {len(lines)} lines")
    logger.info(f"Found {len(data_dict)} valid data points")
    logger.info(f"Models: {models}")
    logger.info(f"Temperatures: {temps}")
    logger.info(f"Tests: {tests}")
    logger.info(f"Servers: {server_names}")

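    # Build one heatmap row per (model, temp) and one column per (server, test);
    # combinations that never ran are left as NaN and render as blank cells.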
    matrix: list[list[float]] = []
    index: list[str] = []

    all_cols = [
        (server_name, test)
        for server_name in server_names
        for test in tests
    ]
    for model in models:
        for temp in temps:
            index.append(f"{model} @ {temp}")
            row_vals = [
                data_dict.get((model, temp, server_name, test), np.nan)
                for server_name, test in all_cols
            ]
            matrix.append(row_vals)

    columns: list[str] = [f"{server_name}\n{test}" for server_name, test in all_cols]

    df = pd.DataFrame(matrix, index=np.array(index), columns=np.array(columns))

    plt.figure(figsize=(12, 6))

    sns.heatmap(
        df, annot=True, cmap="RdYlGn", vmin=0.0, vmax=1.0, cbar=True, fmt=".2f", center=0.5, square=True, linewidths=0.5,
        cbar_kws={"label": "Success Ratio"},
    )

    plt.title(f"Tool Call Bench (n = {str(min(total_counts)) if len(total_counts) == 1 else f'{min(total_counts)}-{max(total_counts)}'})\nSuccess Ratios by Server & Test", pad=20)
    plt.xlabel("Server & Test", labelpad=10)
    plt.ylabel("Model @ Temperature", labelpad=10)

    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)

    plt.tight_layout()

    if output:
        plt.savefig(output, dpi=300, bbox_inches='tight')
        logger.info(f"Plot saved to {output}")
    else:
        plt.show()


@app.command()
def run(
    output: Annotated[Path, typer.Option(help="Output JSONL file")],
    model: Annotated[Optional[str], typer.Option(help="Name of the model to test (server agnostic)")] = None,
    hf: Annotated[Optional[str], typer.Option(help="GGUF huggingface model repo id (+ optional quant) to test w/ llama-server")] = None,
    chat_template: Annotated[Optional[str], typer.Option(help="Chat template override for llama-server")] = None,
    ollama: Annotated[Optional[str], typer.Option(help="Ollama model tag to test")] = None,
    llama_baseline: Annotated[Optional[str], typer.Option(help="llama-server binary path to use as baseline")] = None,
    n: Annotated[int, typer.Option(help="Number of times to run each test")] = 10,
    temp: Annotated[Optional[List[float]], typer.Option(help="Set of temperatures to test (a negative value means server default)")] = None,
    top_p: Annotated[Optional[float], typer.Option(help="top_p sampling parameter")] = None,
    top_k: Annotated[Optional[int], typer.Option(help="top_k sampling parameter")] = None,
    ctk: Annotated[Optional[str], typer.Option(help="KV cache K type (ctk)")] = None,
    ctv: Annotated[Optional[str], typer.Option(help="KV cache V type (ctv)")] = None,
    fa: Annotated[Optional[bool], typer.Option(help="Enable flash attention (fa)")] = None,
    seed: Annotated[Optional[int], typer.Option(help="Random seed")] = None,
    port: Annotated[int, typer.Option(help="llama-server port")] = 8084,
    force: Annotated[bool, typer.Option(help="Force overwrite of output file")] = False,
    append: Annotated[bool, typer.Option(help="Append to output file")] = False,

    test_hello_world: Annotated[bool, typer.Option(help="Whether to run the hello world test")] = True,
    test_weather: Annotated[bool, typer.Option(help="Whether to run the weather test")] = True,
    test_calc_result: Annotated[bool, typer.Option(help="Whether to run the calc result test")] = False,
):
    # Refuse to clobber an existing output file unless --force or --append is given (asserted below)

    n_predict = 512  # High because of DeepSeek R1
    # n_ctx = 8192
    n_ctx = 2048

    assert force or append or not output.exists(), f"Output file already exists: {output}; use --force to overwrite or --append to extend"

    with output.open('a' if append else 'w') as output_file:

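        # Note: this inner helper intentionally shadows the `run` CLI command;
        # it is only called from the temperature loop further down.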
        def run(server: ServerProcess, *, server_name: str, model_id: str, temp: Optional[float] = None, output_kwargs={}, request_kwargs={}):
            request_kwargs = {**request_kwargs}
            if temp is not None:
                request_kwargs['temperature'] = temp
            if top_p is not None:
                request_kwargs['top_p'] = top_p
            if top_k is not None:
                request_kwargs['top_k'] = top_k
            if seed is not None:
                request_kwargs['seed'] = seed

            request_kwargs['cache_prompt'] = False

            tests = {}
            if test_hello_world:
                tests["hello world"] = lambda server: do_test_hello_world(server, **request_kwargs)
            if test_weather:
                tests["weather"] = lambda server: do_test_weather(server, **request_kwargs)
            if test_calc_result:
                tests["calc result"] = lambda server: do_test_calc_result(server, None, 512, **request_kwargs)

            for test_name, test in tests.items():
                success_count = 0
                failure_count = 0
                failures = []
                success_times = []
                failure_times = []
                logger.info(f"Running {test_name} ({server_name}, {model}): ")
                for i in range(n):
                    start_time = time.time()

                    def elapsed():
                        return time.time() - start_time

                    try:
                        test(server)
                        success_times.append(elapsed())
                        success_count += 1
                        logger.info('success')
                    except Exception as e:
                        logger.error(f'failure: {e}')
                        failure_count += 1
                        failure_times.append(elapsed())
                        failures.append(str(e))
                        # import traceback
                        # traceback.print_exc()
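                # One JSONL record per (server, model, temp, test) aggregate;
                # this is exactly the schema the `plot` command reads back.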
                output_file.write(json.dumps({**output_kwargs, **dict(
                    model=model,
                    server_name=server_name,
                    model_id=model_id,
                    test=test_name,
                    temp=temp,
                    top_p=top_p,
                    top_k=top_k,
                    ctk=ctk,
                    ctv=ctv,
                    seed=seed,
                    success_ratio=float(success_count) / n,
                    avg_time=mean(success_times + failure_times),
                    median_time=median(success_times + failure_times),
                    success_count=success_count,
                    success_times=success_times,
                    failure_count=failure_count,
                    failure_times=failure_times,
                    failures=list(set(failures)),
                )}) + '\n')
                output_file.flush()

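        # A negative --temp value is a sentinel for "omit the temperature field"
        # (i.e. use the server's default sampling settings).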
        for t in [None] if temp is None else [t if t >= 0 else None for t in temp]:
            if hf is not None:

                servers: list[Tuple[str, Optional[str]]] = [('llama-server', None)]
                if llama_baseline is not None:
                    servers.append(('llama-server (baseline)', llama_baseline))

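                # A fresh llama-server instance is spawned for every
                # (temperature, binary) combination.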
                for server_name, server_path in servers:
                    server = ServerProcess()
                    server.n_ctx = n_ctx
                    server.n_slots = 1
                    server.jinja = True
                    server.ctk = ctk
                    server.ctv = ctv
                    server.fa = fa
                    server.n_predict = n_predict
                    server.model_hf_repo = hf
                    server.model_hf_file = None
                    server.chat_template = chat_template
                    server.server_path = server_path
                    if port is not None:
                        server.server_port = port
                    # server.debug = True

                    with scoped_server(server):
                        server.start(timeout_seconds=TIMEOUT_SERVER_START)
                        for ignore_chat_grammar in [False]:
                            run(
                                server,
                                server_name=server_name,
                                model_id=hf,
                                temp=t,
                                output_kwargs=dict(
                                    chat_template=chat_template,
                                ),
                                request_kwargs=dict(
                                    ignore_chat_grammar=ignore_chat_grammar,
                                ),
                            )

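            # ollama is not started by this script: it is assumed to already be
            # listening on its default port (11434); ServerProcess is only used
            # here as a client handle for issuing requests.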
            if ollama is not None:
                server = ServerProcess()
                server.server_port = 11434
                server.server_host = "localhost"
                subprocess.check_call(["ollama", "pull", ollama])

                with scoped_server(server):
                    run(
                        server,
                        server_name="ollama",
                        model_id=ollama,
                        temp=t,
                        output_kwargs=dict(
                            chat_template=None,
                        ),
                        request_kwargs=dict(
                            model=ollama,
                            max_tokens=n_predict,
                            num_ctx=n_ctx,
                        ),
                    )


if __name__ == "__main__":
    app()