From d8c17629ac18d21825aa91bf70cb82f3fa30c2a0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 30 Aug 2025 16:08:00 +0300 Subject: [PATCH] examples : add compare-mlx --- examples/compare-mlx/.gitignore | 2 + examples/compare-mlx/compare-mlx.sh | 706 ++++++++++++++++++++++++++ examples/compare-mlx/inspect_model.py | 120 +++++ examples/compare-mlx/mlx-ppl.py | 305 +++++++++++ 4 files changed, 1133 insertions(+) create mode 100644 examples/compare-mlx/.gitignore create mode 100755 examples/compare-mlx/compare-mlx.sh create mode 100644 examples/compare-mlx/inspect_model.py create mode 100644 examples/compare-mlx/mlx-ppl.py diff --git a/examples/compare-mlx/.gitignore b/examples/compare-mlx/.gitignore new file mode 100644 index 0000000000..cb7ebca84a --- /dev/null +++ b/examples/compare-mlx/.gitignore @@ -0,0 +1,2 @@ +*.txt +*/ diff --git a/examples/compare-mlx/compare-mlx.sh b/examples/compare-mlx/compare-mlx.sh new file mode 100755 index 0000000000..9c4b57690b --- /dev/null +++ b/examples/compare-mlx/compare-mlx.sh @@ -0,0 +1,706 @@ +#!/bin/bash + +# a script to compare MLX and GGUF models +# +# usage: +# ./compare-mlx.sh --raw-path wiki.test.raw --no-keep +# +# TODOs +# - add QAT evals + +# check if LLAMA_HOME_DIR is set +if [[ -z "$LLAMA_HOME_DIR" ]]; then + lcpp_dir=$(cd "$(dirname "${BASH_SOURCE[0]}")"/../../ && pwd) +else + lcpp_dir="${LLAMA_HOME_DIR}" +fi + +echo "Using llama.cpp directory: ${lcpp_dir}" + +# check for convert_hf_to_gguf.py +if [[ ! -f "${lcpp_dir}/convert_hf_to_gguf.py" ]]; then + echo "convert_hf_to_gguf.py not found in ${lcpp_dir}" + echo "Set a LLAMA_HOME_DIR environment variable to point to your llama.cpp directory" + exit 1 +fi + +set -x + +# sanity checks that all Python dependencies are installed +if ! python -c "import mlx.core"; then + echo "MLX not found. Please install MLX" + exit 1 +fi + +if ! python ${lcpp_dir}/convert_hf_to_gguf.py --help; then + echo "convert_hf_to_gguf.py not working. Please install llama.cpp python requirements" + exit 1 +fi + +# by default use the system binaries (for example from brew) +llama_perplexity="llama-perplexity" + +if [[ ! -z "$LLAMA_PERPLEXITY" ]]; then + llama_perplexity="$LLAMA_PERPLEXITY" +fi + +echo "Using llama-perplexity: ${llama_perplexity}" + +if ! command -v "$llama_perplexity" &> /dev/null; then + echo "llama-perplexity not found. Please install it." + exit 1 +fi + +llama_quantize="llama-quantize" + +if [[ ! -z "$LLAMA_QUANTIZE" ]]; then + llama_quantize="$LLAMA_QUANTIZE" +fi + +echo "Using llama-quantize: ${llama_quantize}" + +if ! command -v "$llama_quantize" &> /dev/null; then + echo "llama-quantize not found. Please install it." + exit 1 +fi + +llama_batched_bench="llama-batched-bench" + +if [[ ! -z "$LLAMA_BATCHED_BENCH" ]]; then + llama_batched_bench="$LLAMA_BATCHED_BENCH" +fi + +echo "Using llama-batched-bench: ${llama_batched_bench}" + +if ! command -v "$llama_batched_bench" &> /dev/null; then + echo "llama-batched-bench not found. Please install it." 
+    exit 1
+fi
+
+# get the size in GiB
+get_size() {
+    local path="$1"
+    local bytes=$(du -s "$path" | awk '{print $1}')
+    local res=$(echo "scale=3; ($bytes*512)/1024/1024/1024" | bc)
+    echo "$res"
+}
+
+# parameters:
+# --no-compute : do not compute anything, just summarize the existing results
+# --no-ppl : do not compute ppl
+# --no-perf : do not compute performance (speed) metrics
+# --no-keep : delete intermediate model files
+# --num-samples : number of text samples to evaluate (default: 512)
+# --sequence-length : sequence length of the samples in tokens (default: 512)
+# --raw-path : file with raw text (such as wikitext)
+
+# extra arguments for the llama.cpp tools
+args_lcpp="-t 1"
+
+num_samples=512
+sequence_length=512
+raw_path=""
+no_compute=false
+no_ppl=false
+no_perf=false
+no_keep=false
+
+while [[ $# -gt 0 ]]; do
+    case $1 in
+        --no-compute)
+            no_compute=true
+            shift
+            ;;
+        --no-ppl)
+            no_ppl=true
+            shift
+            ;;
+        --no-perf)
+            no_perf=true
+            shift
+            ;;
+        --no-keep)
+            no_keep=true
+            shift
+            ;;
+        --num-samples)
+            num_samples="$2"
+            shift 2
+            ;;
+        --sequence-length)
+            sequence_length="$2"
+            shift 2
+            ;;
+        --raw-path)
+            raw_path="$2"
+            shift 2
+            ;;
+        *)
+            echo "Unknown parameter: $1"
+            exit 1
+            ;;
+    esac
+done
+
+if [[ -z "$raw_path" ]]; then
+    echo "No raw path provided"
+    echo "Recommended to use the test set of WikiText from here: https://github.com/ggml-org/llama.cpp/blob/master/scripts/get-wikitext-2.sh"
+    exit 1
+fi
+
+eval_model() {
+    org="$1"
+    mid="$2"
+
+    echo "Evaluating ${org}/${mid}"
+
+    huggingface-cli download ${org}/${mid} --local-dir ${org}/${mid}
+
+    # generate and process MLX models
+
+    if [[ "$no_compute" == true ]]; then
+        echo "Skipping computation"
+    else
+        rm -rfv ./${mid}-f32-mlx
+        mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-f32-mlx --dtype float32
+        get_size ./${mid}-f32-mlx > ./${mid}-f32-mlx-size.txt
+
+        if [[ "$no_ppl" == false ]]; then
+            python ./mlx-ppl.py --model ./${mid}-f32-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-f32-mlx-ppl.txt
+        fi
+
+        # no need for F32 perf benchmarks
+        #if [[ "$no_perf" == false ]]; then
+        #    mlx_lm.benchmark --model ./${mid}-f32-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f32-mlx-perf-2048.txt
+        #    mlx_lm.benchmark --model ./${mid}-f32-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f32-mlx-perf-4096.txt
+        #    mlx_lm.benchmark --model ./${mid}-f32-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f32-mlx-perf-8192.txt
+        #    mlx_lm.benchmark --model ./${mid}-f32-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f32-mlx-perf-16384.txt
+        #    mlx_lm.benchmark --model ./${mid}-f32-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f32-mlx-perf-32768.txt
+        #fi
+
+        if [[ "$no_keep" == true ]]; then
+            echo "Deleting intermediate model files"
+            rm -rfv ./${mid}-f32-mlx
+        fi
+
+        rm -rfv ./${mid}-bf16-mlx
+        mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-bf16-mlx --dtype bfloat16
+        get_size ./${mid}-bf16-mlx > ./${mid}-bf16-mlx-size.txt
+
+        if [[ "$no_ppl" == false ]]; then
+            python ./mlx-ppl.py --model ./${mid}-bf16-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-bf16-mlx-ppl.txt
+        fi
+
+        if [[ "$no_perf" == false ]]; then
+            mlx_lm.benchmark --model ./${mid}-bf16-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-bf16-mlx-perf-2048.txt
+            mlx_lm.benchmark --model ./${mid}-bf16-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-bf16-mlx-perf-4096.txt
+            mlx_lm.benchmark --model 
./${mid}-bf16-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-bf16-mlx-perf-8192.txt + mlx_lm.benchmark --model ./${mid}-bf16-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-bf16-mlx-perf-16384.txt + mlx_lm.benchmark --model ./${mid}-bf16-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-bf16-mlx-perf-32768.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-bf16-mlx + fi + + rm -rfv ./${mid}-f16-mlx + mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-f16-mlx --dtype float16 + get_size ./${mid}-f16-mlx > ./${mid}-f16-mlx-size.txt + + if [[ "$no_ppl" == false ]]; then + python ./mlx-ppl.py --model ./${mid}-f16-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-f16-mlx-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + mlx_lm.benchmark --model ./${mid}-f16-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f16-mlx-perf-2048.txt + mlx_lm.benchmark --model ./${mid}-f16-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f16-mlx-perf-4096.txt + mlx_lm.benchmark --model ./${mid}-f16-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f16-mlx-perf-8192.txt + mlx_lm.benchmark --model ./${mid}-f16-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f16-mlx-perf-16384.txt + mlx_lm.benchmark --model ./${mid}-f16-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-f16-mlx-perf-32768.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-f16-mlx + fi + + rm -rfv ./${mid}-q8-mlx + mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-q8-mlx --quantize --q-bits 8 + get_size ./${mid}-q8-mlx > ./${mid}-q8-mlx-size.txt + + if [[ "$no_ppl" == false ]]; then + python ./mlx-ppl.py --model ./${mid}-q8-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-q8-mlx-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + mlx_lm.benchmark --model ./${mid}-q8-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q8-mlx-perf-2048.txt + mlx_lm.benchmark --model ./${mid}-q8-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q8-mlx-perf-4096.txt + mlx_lm.benchmark --model ./${mid}-q8-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q8-mlx-perf-8192.txt + mlx_lm.benchmark --model ./${mid}-q8-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q8-mlx-perf-16384.txt + mlx_lm.benchmark --model ./${mid}-q8-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q8-mlx-perf-32768.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-q8-mlx + fi + + rm -rfv ./${mid}-q6-mlx + mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-q6-mlx --quantize --q-bits 6 + get_size ./${mid}-q6-mlx > ./${mid}-q6-mlx-size.txt + + if [[ "$no_ppl" == false ]]; then + python ./mlx-ppl.py --model ./${mid}-q6-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-q6-mlx-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + mlx_lm.benchmark --model ./${mid}-q6-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q6-mlx-perf-2048.txt + mlx_lm.benchmark --model ./${mid}-q6-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q6-mlx-perf-4096.txt + mlx_lm.benchmark --model ./${mid}-q6-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q6-mlx-perf-8192.txt + mlx_lm.benchmark --model ./${mid}-q6-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee 
./${mid}-q6-mlx-perf-16384.txt + mlx_lm.benchmark --model ./${mid}-q6-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q6-mlx-perf-32768.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-q6-mlx + fi + + rm -rfv ./${mid}-q5-mlx + mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-q5-mlx --quantize --q-bits 5 + get_size ./${mid}-q5-mlx > ./${mid}-q5-mlx-size.txt + + if [[ "$no_ppl" == false ]]; then + python ./mlx-ppl.py --model ./${mid}-q5-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-q5-mlx-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + mlx_lm.benchmark --model ./${mid}-q5-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q5-mlx-perf-2048.txt + mlx_lm.benchmark --model ./${mid}-q5-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q5-mlx-perf-4096.txt + mlx_lm.benchmark --model ./${mid}-q5-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q5-mlx-perf-8192.txt + mlx_lm.benchmark --model ./${mid}-q5-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q5-mlx-perf-16384.txt + mlx_lm.benchmark --model ./${mid}-q5-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q5-mlx-perf-32768.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-q5-mlx + fi + + # I think this is something similar to q4_k + rm -rfv ./${mid}-q4p-mlx + mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-q4p-mlx --quantize --quant-predicate mixed_4_6 + get_size ./${mid}-q4p-mlx > ./${mid}-q4p-mlx-size.txt + + if [[ "$no_ppl" == false ]]; then + python ./mlx-ppl.py --model ./${mid}-q4p-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-q4p-mlx-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + mlx_lm.benchmark --model ./${mid}-q4p-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4p-mlx-perf-2048.txt + mlx_lm.benchmark --model ./${mid}-q4p-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4p-mlx-perf-4096.txt + mlx_lm.benchmark --model ./${mid}-q4p-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4p-mlx-perf-8192.txt + mlx_lm.benchmark --model ./${mid}-q4p-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4p-mlx-perf-16384.txt + mlx_lm.benchmark --model ./${mid}-q4p-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4p-mlx-perf-32768.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-q4p-mlx + fi + + rm -rfv ./${mid}-q4-mlx + mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-q4-mlx --quantize --q-bits 4 + get_size ./${mid}-q4-mlx > ./${mid}-q4-mlx-size.txt + + if [[ "$no_ppl" == false ]]; then + python ./mlx-ppl.py --model ./${mid}-q4-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-q4-mlx-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + mlx_lm.benchmark --model ./${mid}-q4-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4-mlx-perf-2048.txt + mlx_lm.benchmark --model ./${mid}-q4-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4-mlx-perf-4096.txt + mlx_lm.benchmark --model ./${mid}-q4-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4-mlx-perf-8192.txt + mlx_lm.benchmark --model ./${mid}-q4-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q4-mlx-perf-16384.txt + mlx_lm.benchmark --model ./${mid}-q4-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee 
./${mid}-q4-mlx-perf-32768.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-q4-mlx + fi + + rm -rfv ./${mid}-q3-mlx + mlx_lm.convert --hf ./${org}/${mid} --mlx-path ./${mid}-q3-mlx --quantize --q-bits 3 + get_size ./${mid}-q3-mlx > ./${mid}-q3-mlx-size.txt + + if [[ "$no_ppl" == false ]]; then + python ./mlx-ppl.py --model ./${mid}-q3-mlx --raw-path "$raw_path" --num-samples "$num_samples" --sequence-length "$sequence_length" 2>&1 | tee ./${mid}-q3-mlx-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + mlx_lm.benchmark --model ./${mid}-q3-mlx -p 2048 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q3-mlx-perf-2048.txt + mlx_lm.benchmark --model ./${mid}-q3-mlx -p 4096 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q3-mlx-perf-4096.txt + mlx_lm.benchmark --model ./${mid}-q3-mlx -p 8192 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q3-mlx-perf-8192.txt + mlx_lm.benchmark --model ./${mid}-q3-mlx -p 16384 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q3-mlx-perf-16384.txt + mlx_lm.benchmark --model ./${mid}-q3-mlx -p 32768 -g 128 --num-trials 1 2>&1 | tee ./${mid}-q3-mlx-perf-32768.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-q3-mlx + fi + fi + + # generate and process llama.cpp GGUF models + + if [[ "$no_compute" == true ]]; then + echo "Skipping computation" + else + # the F32 model is the reference - we generate all other models from it + mkdir -p ./${mid}-f32-gguf + python ${lcpp_dir}/convert_hf_to_gguf.py ./${org}/${mid} --outtype f32 --outfile ./${mid}-f32-gguf/model.gguf + get_size ./${mid}-f32-gguf > ./${mid}-f32-gguf-size.txt + + if [[ "$no_ppl" == false ]]; then + ${llama_perplexity} $args_lcpp -m ./${mid}-f32-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-f32-gguf-ppl.txt + fi + + # no need for F32 perf benchmarks + #if [[ "$no_perf" == false ]]; then + # ${llama_batched_bench} $args_lcpp -m ./${mid}-f32-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-f32-gguf-perf.txt + #fi + + # this requires to explicitly build llama.cpp with BF16 support + rm -rfv ./${mid}-bf16-gguf && mkdir -p ./${mid}-bf16-gguf + ${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-bf16-gguf/model.gguf bf16 + get_size ./${mid}-bf16-gguf > ./${mid}-bf16-gguf-size.txt + + if [[ "$no_ppl" == false ]]; then + ${llama_perplexity} $args_lcpp -m ./${mid}-bf16-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-bf16-gguf-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + ${llama_batched_bench} $args_lcpp -m ./${mid}-bf16-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-bf16-gguf-perf.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-bf16-gguf + fi + + rm -rfv ./${mid}-f16-gguf && mkdir -p ./${mid}-f16-gguf + ${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-f16-gguf/model.gguf f16 + get_size ./${mid}-f16-gguf > ./${mid}-f16-gguf-size.txt + + if [[ "$no_ppl" == false ]]; then + ${llama_perplexity} $args_lcpp -m ./${mid}-f16-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-f16-gguf-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + ${llama_batched_bench} $args_lcpp -m ./${mid}-f16-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 
2>&1 | tee ./${mid}-f16-gguf-perf.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-f16-gguf + fi + + rm -rfv ./${mid}-q8-gguf && mkdir -p ./${mid}-q8-gguf + ${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-q8-gguf/model.gguf q8_0 + get_size ./${mid}-q8-gguf > ./${mid}-q8-gguf-size.txt + + if [[ "$no_ppl" == false ]]; then + ${llama_perplexity} $args_lcpp -m ./${mid}-q8-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-q8-gguf-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + ${llama_batched_bench} $args_lcpp -m ./${mid}-q8-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-q8-gguf-perf.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-q8-gguf + fi + + rm -rfv ./${mid}-q6-gguf && mkdir -p ./${mid}-q6-gguf + ${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-q6-gguf/model.gguf q6_k + get_size ./${mid}-q6-gguf > ./${mid}-q6-gguf-size.txt + + if [[ "$no_ppl" == false ]]; then + ${llama_perplexity} $args_lcpp -m ./${mid}-q6-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-q6-gguf-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + ${llama_batched_bench} $args_lcpp -m ./${mid}-q6-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-q6-gguf-perf.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-q6-gguf + fi + + rm -rfv ./${mid}-q5-gguf && mkdir -p ./${mid}-q5-gguf + ${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-q5-gguf/model.gguf q5_k_s + get_size ./${mid}-q5-gguf > ./${mid}-q5-gguf-size.txt + + if [[ "$no_ppl" == false ]]; then + ${llama_perplexity} $args_lcpp -m ./${mid}-q5-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-q5-gguf-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + ${llama_batched_bench} $args_lcpp -m ./${mid}-q5-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-q5-gguf-perf.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-q5-gguf + fi + + rm -rfv ./${mid}-q4p-gguf && mkdir -p ./${mid}-q4p-gguf + ${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-q4p-gguf/model.gguf q4_k + get_size ./${mid}-q4p-gguf > ./${mid}-q4p-gguf-size.txt + + if [[ "$no_ppl" == false ]]; then + ${llama_perplexity} $args_lcpp -m ./${mid}-q4p-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-q4p-gguf-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + ${llama_batched_bench} $args_lcpp -m ./${mid}-q4p-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-q4p-gguf-perf.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-q4p-gguf + fi + + # note: we use --pure here to match the MLX quantization of the embeddings + rm -rfv ./${mid}-q4-gguf && mkdir -p ./${mid}-q4-gguf + ${llama_quantize} --pure ./${mid}-f32-gguf/model.gguf ./${mid}-q4-gguf/model.gguf q4_0 + get_size ./${mid}-q4-gguf > ./${mid}-q4-gguf-size.txt + + if [[ "$no_ppl" == false ]]; then + ${llama_perplexity} $args_lcpp -m ./${mid}-q4-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c 
"${sequence_length}" 2>&1 | tee ./${mid}-q4-gguf-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + ${llama_batched_bench} $args_lcpp -m ./${mid}-q4-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-q4-gguf-perf.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-q4-gguf + fi + + rm -rfv ./${mid}-q3-gguf && mkdir -p ./${mid}-q3-gguf + ${llama_quantize} ./${mid}-f32-gguf/model.gguf ./${mid}-q3-gguf/model.gguf q3_k_s + get_size ./${mid}-q3-gguf > ./${mid}-q3-gguf-size.txt + + if [[ "$no_ppl" == false ]]; then + ${llama_perplexity} $args_lcpp -m ./${mid}-q3-gguf/model.gguf -f "$raw_path" --chunks "${num_samples}" -c "${sequence_length}" 2>&1 | tee ./${mid}-q3-gguf-ppl.txt + fi + + if [[ "$no_perf" == false ]]; then + ${llama_batched_bench} $args_lcpp -m ./${mid}-q3-gguf/model.gguf -c 33768 -b 2048 -ub 2048 -npp 2048,4096,8192,16384,32768 -ntg 128 -npl 1 2>&1 | tee ./${mid}-q3-gguf-perf.txt + fi + + if [[ "$no_keep" == true ]]; then + echo "Deleting intermediate model files" + rm -rfv ./${mid}-q3-gguf + fi + + # remove the f32 model at the end + if [[ "$no_keep" == true ]]; then + rm -rfv ./${mid}-f32-gguf + fi + fi + + set +x + + # analyze results + + types=("f32" "bf16" "f16" "q8" "q6" "q5" "q4p" "q4" "q3") + + mlx_ppls=() + mlx_ppl_deltas=() + mlx_sizes=() + mlx_pps2k=() + mlx_tgs2k=() + mlx_pps4k=() + mlx_tgs4k=() + mlx_pps8k=() + mlx_tgs8k=() + mlx_pps16k=() + mlx_tgs16k=() + mlx_pps32k=() + mlx_tgs32k=() + + # mlx: + for t in ${types[*]}; do + cur_ppl="N/A" + cur_ppl_delta="N/A" + cur_size="N/A" + cur_pp2k="N/A" + cur_tg2k="N/A" + cur_pp4k="N/A" + cur_tg4k="N/A" + cur_pp8k="N/A" + cur_tg8k="N/A" + cur_pp16k="N/A" + cur_tg16k="N/A" + cur_pp32k="N/A" + cur_tg32k="N/A" + + if [[ -f ./${mid}-${t}-mlx-ppl.txt ]]; then + cur_ppl=$(grep -o 'Perplexity: [0-9.]*' ./${mid}-${t}-mlx-ppl.txt | cut -d' ' -f2) + cur_ppl_delta=$(grep -o 'Perplexity: [0-9.]* ± [0-9.]*' ./${mid}-${t}-mlx-ppl.txt | cut -d' ' -f4) + cur_size=$(cat ./${mid}-${t}-mlx-size.txt) + cur_pp2k=$(grep -o 'Averages.*prompt_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-2048.txt | cut -d'=' -f2) + cur_tg2k=$(grep -o 'Averages.*generation_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-2048.txt | cut -d'=' -f3) + cur_pp4k=$(grep -o 'Averages.*prompt_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-4096.txt | cut -d'=' -f2) + cur_tg4k=$(grep -o 'Averages.*generation_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-4096.txt | cut -d'=' -f3) + cur_pp8k=$(grep -o 'Averages.*prompt_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-8192.txt | cut -d'=' -f2) + cur_tg8k=$(grep -o 'Averages.*generation_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-8192.txt | cut -d'=' -f3) + cur_pp16k=$(grep -o 'Averages.*prompt_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-16384.txt | cut -d'=' -f2) + cur_tg16k=$(grep -o 'Averages.*generation_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-16384.txt | cut -d'=' -f3) + cur_pp32k=$(grep -o 'Averages.*prompt_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-32768.txt | cut -d'=' -f2) + cur_tg32k=$(grep -o 'Averages.*generation_tps=[0-9.]*' ./${mid}-${t}-mlx-perf-32768.txt | cut -d'=' -f3) + fi + + mlx_ppls+=("${cur_ppl}") + mlx_ppl_deltas+=("${cur_ppl_delta}") + mlx_sizes+=("${cur_size}") + mlx_pps2k+=("${cur_pp2k}") + mlx_tgs2k+=("${cur_tg2k}") + mlx_pps4k+=("${cur_pp4k}") + mlx_tgs4k+=("${cur_tg4k}") + mlx_pps8k+=("${cur_pp8k}") + mlx_tgs8k+=("${cur_tg8k}") + mlx_pps16k+=("${cur_pp16k}") + mlx_tgs16k+=("${cur_tg16k}") + mlx_pps32k+=("${cur_pp32k}") + mlx_tgs32k+=("${cur_tg32k}") + done + + 
gguf_ppls=() + gguf_ppl_deltas=() + gguf_sizes=() + gguf_pps2k=() + gguf_tgs2k=() + gguf_pps4k=() + gguf_tgs4k=() + gguf_pps8k=() + gguf_tgs8k=() + gguf_pps16k=() + gguf_tgs16k=() + gguf_pps32k=() + gguf_tgs32k=() + + # gguf: + for t in ${types[*]}; do + cur_ppl="N/A" + cur_ppl_delta="N/A" + cur_size="N/A" + cur_pp2k="N/A" + cur_tg2k="N/A" + cur_pp4k="N/A" + cur_tg4k="N/A" + cur_pp8k="N/A" + cur_tg8k="N/A" + cur_pp16k="N/A" + cur_tg16k="N/A" + cur_pp32k="N/A" + cur_tg32k="N/A" + + if [[ -f ./${mid}-${t}-gguf-ppl.txt ]]; then + cur_ppl=$(grep -o 'Final estimate: PPL = [0-9.]*' ./${mid}-${t}-gguf-ppl.txt | sed -e "s/.*Final//" | cut -d' ' -f5) + cur_ppl_delta=$(grep -o 'Final estimate: PPL = [0-9.]* +/- [0-9.]*' ./${mid}-${t}-gguf-ppl.txt | sed -e "s/.*Final//" | cut -d' ' -f7) + cur_size=$(cat ./${mid}-${t}-gguf-size.txt) + cur_pp2k=$(grep -o '| 2048 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $12}') + cur_tg2k=$(grep -o '| 2048 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $16}') + cur_pp4k=$(grep -o '| 4096 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $12}') + cur_tg4k=$(grep -o '| 4096 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $16}') + cur_pp8k=$(grep -o '| 8192 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $12}') + cur_tg8k=$(grep -o '| 8192 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $16}') + cur_pp16k=$(grep -o '| 16384 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $12}') + cur_tg16k=$(grep -o '| 16384 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $16}') + cur_pp32k=$(grep -o '| 32768 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $12}') + cur_tg32k=$(grep -o '| 32768 |.*' ./${mid}-${t}-gguf-perf.txt | awk '{print $16}') + fi + + gguf_ppls+=("${cur_ppl}") + gguf_ppl_deltas+=("${cur_ppl_delta}") + gguf_sizes+=("${cur_size}") + gguf_pps2k+=("${cur_pp2k}") + gguf_tgs2k+=("${cur_tg2k}") + gguf_pps4k+=("${cur_pp4k}") + gguf_tgs4k+=("${cur_tg4k}") + gguf_pps8k+=("${cur_pp8k}") + gguf_tgs8k+=("${cur_tg8k}") + gguf_pps16k+=("${cur_pp16k}") + gguf_tgs16k+=("${cur_tg16k}") + gguf_pps32k+=("${cur_pp32k}") + gguf_tgs32k+=("${cur_tg32k}") + done + + res="${mid}-results.txt" + echo "Results for ${org}/${mid} saved to ${res}" + + printf "\n" | tee ${res} + printf "Model ID: ${org}/${mid}\n" | tee -a ${res} + #printf "Samples: ${num_samples}\n" | tee -a ${res} + #printf "Sequence Length: ${sequence_length}\n" | tee -a ${res} + printf "\n" | tee -a ${res} + printf "| Type | MLX PPL | GGUF PPL | MLX Size | GGUF Size | MLX PP 2K | GGUF PP 2K | MLX TG 2K | GGUF TG 2K | MLX PP 4K | GGUF PP 4K | MLX TG 4K | GGUF TG 4K | MLX PP 8K | GGUF PP 8K | MLX TG 8K | GGUF TG 8K | MLX PP 16K | GGUF PP 16K | MLX TG 16K | GGUF TG 16K | MLX PP 32K | GGUF PP 32K | MLX TG 32K | GGUF TG 32K |\n" | tee -a ${res} + printf "|-------|---------------------|------------------------|----------|-----------| ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- | ---------- | ----------- |\n" | tee -a ${res} + + for i in "${!types[@]}"; do + printf "| %-5s | %10s ± %6s | %10s ± %9s | %8s | %9s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s | %10s | %11s |\n" \ + "${types[i]}" \ + "${mlx_ppls[i]}" \ + "${mlx_ppl_deltas[i]}" \ + "${gguf_ppls[i]}" \ + "${gguf_ppl_deltas[i]}" \ + "${mlx_sizes[i]}" \ + "${gguf_sizes[i]}" \ + "${mlx_pps2k[i]}" \ + 
"${gguf_pps2k[i]}" \ + "${mlx_tgs2k[i]}" \ + "${gguf_tgs2k[i]}" \ + "${mlx_pps4k[i]}" \ + "${gguf_pps4k[i]}" \ + "${mlx_tgs4k[i]}" \ + "${gguf_tgs4k[i]}" \ + "${mlx_pps8k[i]}" \ + "${gguf_pps8k[i]}" \ + "${mlx_tgs8k[i]}" \ + "${gguf_tgs8k[i]}" \ + "${mlx_pps16k[i]}" \ + "${gguf_pps16k[i]}" \ + "${mlx_tgs16k[i]}" \ + "${gguf_tgs16k[i]}" \ + "${mlx_pps32k[i]}" \ + "${gguf_pps32k[i]}" \ + "${mlx_tgs32k[i]}" \ + "${gguf_tgs32k[i]}" | tee -a ${res} + done +} + +eval_model "meta-llama" "Llama-3.2-1B" +eval_model "meta-llama" "Llama-3.2-3B" +eval_model "meta-llama" "Llama-3.1-8B" + +eval_model "google" "gemma-3-270m" +eval_model "google" "gemma-3-1b-pt" +eval_model "google" "gemma-3-4b-pt" + +# the mlx-ppl.y script does not work with these models - not sure why +#eval_model "google" "gemma-3n-E2B" +#eval_model "google" "gemma-3n-E4B" + +eval_model "Qwen" "Qwen3-0.6B-Base" +eval_model "Qwen" "Qwen3-1.7B-Base" +eval_model "Qwen" "Qwen3-4B-Base" +eval_model "Qwen" "Qwen3-8B-Base" +eval_model "Qwen" "Qwen3-30B-A3B-Base" diff --git a/examples/compare-mlx/inspect_model.py b/examples/compare-mlx/inspect_model.py new file mode 100644 index 0000000000..cfd93c395d --- /dev/null +++ b/examples/compare-mlx/inspect_model.py @@ -0,0 +1,120 @@ +#!/usr/bin/env python3 +# generated by Claude +""" +Script to inspect SafeTensors model files and print tensor information. +""" + +import json +from safetensors import safe_open +import os +from pathlib import Path + +def inspect_safetensors_model(model_dir="."): + """Inspect all SafeTensors files in the model directory.""" + + # First, let's read the index file to see the file structure + index_file = Path(model_dir) / "model.safetensors.index.json" + + if index_file.exists(): + with open(index_file, 'r') as f: + index_data = json.load(f) + + print("=== Model Structure ===") + print(f"Total parameters: {index_data.get('metadata', {}).get('total_size', 'Unknown')}") + print() + + # Get all safetensor files + safetensor_files = set(index_data.get('weight_map', {}).values()) + else: + # If no index file, look for safetensor files directly + safetensor_files = [f for f in os.listdir(model_dir) if f.endswith('.safetensors')] + + # Sort files for consistent output + safetensor_files = sorted(safetensor_files) + + print("=== Tensor Information ===") + print(f"{'Tensor Name':<50} {'Shape':<25} {'Data Type':<15} {'File'}") + print("-" * 110) + + total_tensors = 0 + + for filename in safetensor_files: + filepath = Path(model_dir) / filename + if not filepath.exists(): + continue + + print(f"\n--- {filename} ---") + + # Open and inspect the safetensor file + with safe_open(filepath, framework="pt") as f: # Use PyTorch framework for better dtype support + tensor_names = f.keys() + + for tensor_name in sorted(tensor_names): + # Get tensor metadata without loading the full tensor + tensor_slice = f.get_slice(tensor_name) + shape = tensor_slice.get_shape() + dtype = tensor_slice.get_dtype() + + shape_str = str(tuple(shape)) + dtype_str = str(dtype) + + print(f"{tensor_name:<50} {shape_str:<25} {dtype_str:<15} {filename}") + total_tensors += 1 + + print(f"\nTotal tensors found: {total_tensors}") + +def main(): + import argparse + + parser = argparse.ArgumentParser(description="Inspect SafeTensors model files") + parser.add_argument("--model-dir", "-d", default=".", + help="Directory containing the model files (default: current directory)") + parser.add_argument("--summary", "-s", action="store_true", + help="Show only summary statistics") + + args = parser.parse_args() + + if 
args.summary: + print_summary_only(args.model_dir) + else: + inspect_safetensors_model(args.model_dir) + +def print_summary_only(model_dir="."): + """Print only summary statistics.""" + safetensor_files = [f for f in os.listdir(model_dir) if f.endswith('.safetensors')] + + total_tensors = 0 + dtype_counts = {} + total_params = 0 + + for filename in sorted(safetensor_files): + filepath = Path(model_dir) / filename + if not filepath.exists(): + continue + + with safe_open(filepath, framework="pt") as f: # Use PyTorch framework + for tensor_name in f.keys(): + tensor_slice = f.get_slice(tensor_name) + shape = tensor_slice.get_shape() + dtype = tensor_slice.get_dtype() + + total_tensors += 1 + + dtype_str = str(dtype) + dtype_counts[dtype_str] = dtype_counts.get(dtype_str, 0) + 1 + + # Calculate parameter count + param_count = 1 + for dim in shape: + param_count *= dim + total_params += param_count + + print("=== Model Summary ===") + print(f"Total tensors: {total_tensors}") + print(f"Total parameters: {total_params:,}") + print(f"Data type distribution:") + for dtype, count in sorted(dtype_counts.items()): + print(f" {dtype}: {count} tensors") + +if __name__ == "__main__": + main() diff --git a/examples/compare-mlx/mlx-ppl.py b/examples/compare-mlx/mlx-ppl.py new file mode 100644 index 0000000000..4752839c50 --- /dev/null +++ b/examples/compare-mlx/mlx-ppl.py @@ -0,0 +1,305 @@ +# Copyright © 2025 Apple Inc. +# modified: https://github.com/ml-explore/mlx-lm/blob/60320dc2347d45dc3ca08be90e5255fb9424bb09/mlx_lm/perplexity.py +""" +Evaluate perplexity (PPL) of pre-trained MLX models in the same way as llama.cpp's llama-perplexity. +""" + +import argparse +import math +import os +import time +import types + +import mlx.core as mx +import mlx.nn as nn +import numpy as np + +from mlx_lm.tuner.datasets import load_dataset +from mlx_lm.tuner.utils import get_total_parameters +from mlx_lm.utils import load + + +def load_data( + tokenizer, + data_path: str, + num_samples: int, + sequence_length: int, +): + """ + Load a Hugging‑Face dataset (via mlx‑lm’s dataset utilities) and convert it + into a token tensor of shape (N, sequence_length). + """ + args = types.SimpleNamespace( + hf_dataset={ + "path": data_path, + "train_split": "train", + "valid_split": "train[:1]", + }, + train=True, + test=False, + ) + dataset = load_dataset(args, tokenizer)[0] + + perm = np.random.permutation(len(dataset)).tolist() + + num_tokens = sequence_length * num_samples if num_samples > 0 else float("inf") + data = [] + i = 0 + while len(data) < num_tokens: + tokens, _ = dataset.process(dataset[perm[i]]) + i += 1 + data.extend(tokens) + + # Convert to MX array, truncate to a multiple of `sequence_length` + data = mx.array(data[: (len(data) // sequence_length) * sequence_length]) + data = data.reshape(-1, sequence_length) + if num_samples > 0: + data = data[:num_samples] + return data + + +def _tokenize_text(tokenizer, text: str): + """ + Helper that tokenises a string using the MLX‑LM tokenizer. + Supports the common `encode` method or a callable tokenizer. + """ + # Most mlx‑lm tokenizers expose an `encode` method. + if hasattr(tokenizer, "encode"): + tokens = tokenizer.encode(text) + elif callable(tokenizer): + tokens = tokenizer(text) + else: + raise AttributeError( + "Tokenizer does not have an `encode` method nor is it callable." + ) + # Normalise the output to a Python list of ints. 
+ if isinstance(tokens, mx.array): + tokens = tokens.tolist() + return tokens + + +# load a raw text file and tokenize it +# generated with gpt-oss-120b +def load_raw_data( + tokenizer, + raw_path: str, + num_samples: int, + sequence_length: int, +): + """ + Load a raw text file, tokenize it, and reshape into a (N, sequence_length) + tensor suitable for perplexity evaluation. + """ + if not os.path.isfile(raw_path): + raise FileNotFoundError(f"Raw text file not found: {raw_path}") + + # Read the whole file (UTF‑8). Users can supply any plain‑text corpus. + with open(raw_path, "r", encoding="utf-8") as fp: + raw_text = fp.read() + + # Tokenise the complete text. + token_list = _tokenize_text(tokenizer, raw_text) + + if len(token_list) == 0: + raise ValueError("Tokenisation of the raw file produced no tokens.") + + # Convert to MX array (int32 is sufficient for token IDs). + token_array = mx.array(token_list, dtype=mx.int32) + + # Trim to a length that is an exact multiple of `sequence_length`. + total_len = (token_array.shape[0] // sequence_length) * sequence_length + token_array = token_array[:total_len] + + # Reshape into (num_sequences, sequence_length) + data = token_array.reshape(-1, sequence_length) + + if num_samples > 0: + data = data[:num_samples] + + #print(f"First 4 samples of the data:") + #for j in range(min(4, len(data))): + # print(f" Sample {j}: {tokenizer.decode(data[j].tolist())}\n\n-------------------\n\n") + + return data + + +def eval_ppl(model, tokenizer, data, batch_size=8): + """ + Evaluate perplexity on a dataset with standard error calculation. + + Args: + model: The model to evaluate. + data: Tokenized data tensor (shape: N x L). + batch_size: Batch size for evaluation. + + Returns: + tuple: (perplexity, standard_error_of_perplexity) + """ + all_losses = [] + + num_batches = (len(data) + batch_size - 1) // batch_size + for i, s in enumerate(range(0, len(data), batch_size)): + batch = data[s : s + batch_size] + + # Set the first token of all samples to the BOS token + if tokenizer.bos_token_id: + batch[:, 0] = tokenizer.bos_token_id + + # compute cross entropy only with the second half of the sequence to match llama.cpp behavior + # ref: https://github.com/ggml-org/llama.cpp/blob/696fccf354e9dbdfbce135bc40b44c9dcc64dda9/tools/perplexity/perplexity.cpp#L527-L541 + # + #start = 0 + start = batch.shape[1] // 2 + + # Forward pass: get logits for all tokens except last + logits = model(batch[:, :-1]).astype(mx.float32) + + # Calculate cross‑entropy loss with next tokens + #losses = nn.losses.cross_entropy(logits, batch[:, 1:], reduction="none") + losses = nn.losses.cross_entropy(logits[:, start:, :], batch[:, start+1:], reduction="none") + + mx.eval(losses) + # Store individual token losses + all_losses.append(losses.flatten()) + + # Progress indicator + if (i + 1) % 1 == 0 or (i + 1) == num_batches: + print(f" Processed {i + 1}/{num_batches} batches...", end="\r") + + print() # New line after progress + + # Concatenate all losses into a single array + all_losses = mx.concatenate(all_losses) + + # Calculate mean loss and perplexity + mean_loss = all_losses.mean().item() + ppl = math.exp(mean_loss) + + # Calculate standard error + std_dev = mx.sqrt(mx.var(all_losses, ddof=1)).item() + num_tokens = all_losses.size + standard_error = std_dev / math.sqrt(num_tokens) + + # Delta approximation for standard error of perplexity + standard_error_ppl = ppl * standard_error + + return ppl, standard_error_ppl + + +def main(): + parser = argparse.ArgumentParser(description="Evaluate 
perplexity of MLX models") + parser.add_argument( + "--model", + type=str, + required=True, + help="Path to model or Hugging Face model ID", + ) + parser.add_argument( + "--batch-size", type=int, default=8, help="Batch size for evaluation" + ) + parser.add_argument( + "--sequence-length", + type=int, + default=512, + help="Sequence length for evaluation", + ) + parser.add_argument( + "--num-samples", + type=int, + default=256, + help="Number of samples to use (-1 for all available)", + ) + parser.add_argument( + "--data-path", + type=str, + default="allenai/tulu-3-sft-mixture", + help=( + "A Hugging Face dataset compatible with mlx‑lm. " + "Ignored if --raw-path is provided." + ), + ) + parser.add_argument( + "--raw-path", + type=str, + default=None, + help=( + "Path to a local raw‑text file to use for evaluation. " + "If specified, the script skips loading a HF dataset." + ), + ) + parser.add_argument( + "--seed", type=int, default=123, help="Random seed for data sampling" + ) + + args = parser.parse_args() + + # Set random seed (used for HF dataset shuffling) + mx.random.seed(args.seed) + + # Load model + print(f"Loading model from {args.model}...") + model, tokenizer = load(args.model) + + # Count parameters + total_params = get_total_parameters(model) + print(f"Model loaded: {total_params/1e6:.1f}M parameters") + + # ---------------------------------------------------------------------- + # Load evaluation data (raw file vs. HF dataset) + # ---------------------------------------------------------------------- + print("\nLoading dataset...") + print(f" Sequence length: {args.sequence_length}") + + if args.raw_path: + print(f" Using raw text file: {args.raw_path}") + data = load_raw_data( + tokenizer, + raw_path=args.raw_path, + num_samples=args.num_samples, + sequence_length=args.sequence_length, + ) + else: + print(f" Using HF dataset: {args.data_path}") + data = load_data( + tokenizer, + data_path=args.data_path, + num_samples=args.num_samples, + sequence_length=args.sequence_length, + ) + + print(f" Loaded {len(data)} samples") + + # ---------------------------------------------------------------------- + # Evaluate perplexity + # ---------------------------------------------------------------------- + print(f"\nEvaluating perplexity with batch size {args.batch_size}...") + start_time = time.time() + + ppl, se = eval_ppl(model, tokenizer, data, batch_size=args.batch_size) + + eval_time = time.time() - start_time + tokens_evaluated = data.shape[0] * (data.shape[1] - 1) # B * (L - 1) + + # Print results + print("\n" + "=" * 60) + print("EVALUATION RESULTS") + print("=" * 60) + print(f"Model: {args.model}") + print(f"Perplexity: {ppl:.3f} ± {se:.3f}") + print(f"Evaluation time: {eval_time:.2f} seconds") + print(f"Peak memory: {mx.get_peak_memory() / 1e9:.2f} GB") + print(f"Tokens per second: {tokens_evaluated / eval_time:.0f}") + + # Additional statistics + print(f"\nDataset statistics:") + print(f" Total samples: {len(data)}") + print(f" Total tokens: {data.size}") + + # ---------------------------------------------------------------------- + # Done + # ---------------------------------------------------------------------- + + +if __name__ == "__main__": + main() +
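+
+# Example invocation (a sketch; the model path below is hypothetical - any local MLX
+# model directory or Hugging Face model ID works, and wiki.test.raw is the WikiText-2
+# test set referenced in compare-mlx.sh):
+#
+#   python mlx-ppl.py --model ./Llama-3.2-1B-q4-mlx --raw-path wiki.test.raw \
+#       --num-samples 512 --sequence-length 512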