From fb371c18ec5035a97f7ad2bd0cacb1de53d0b2c3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 28 Jul 2025 21:53:18 +0300 Subject: [PATCH] bench,common : add CPU extra buffer types --- common/arg.cpp | 19 +++++++++++++++++++ tools/llama-bench/llama-bench.cpp | 19 +++++++++++++++++++ 2 files changed, 38 insertions(+) diff --git a/common/arg.cpp b/common/arg.cpp index 060053595d..1cf417f0d9 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2347,6 +2347,25 @@ common_params_context common_params_parser_init(common_params & params, llama_ex buft_list[ggml_backend_buft_name(buft)] = buft; } } + + // add CPU extra buffer types + { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error("no CPU backend found"); + } + + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list[ggml_backend_buft_name(*extra_bufts)] = *extra_bufts; + ++extra_bufts; + } + } + } } for (const auto & override : string_split(value, ',')) { diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index c56834a2a6..24faa2b47f 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -702,6 +702,25 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { buft_list[ggml_backend_buft_name(buft)] = buft; } } + + // add CPU extra buffer types + { + auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + if (cpu_dev == nullptr) { + throw std::runtime_error("no CPU backend found"); + } + + auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev); + auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t) + ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts"); + if (ggml_backend_dev_get_extra_bufts_fn) { + ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev); + while (extra_bufts && *extra_bufts) { + buft_list[ggml_backend_buft_name(*extra_bufts)] = *extra_bufts; + ++extra_bufts; + } + } + } } auto override_group_span_len = std::strcspn(value, ","); bool last_group = false;