llama : separate compute buffer reserve from fattn check (#15696)

Exposes ggml_backend_sched_split_graph() to allow splitting the graph without allocating compute buffers and uses it to split the graph for the automatic Flash Attention check.
This commit is contained in:
Diego Devesa
2025-08-31 06:49:03 -07:00
committed by GitHub
parent 7d3c9f2b21
commit 9777032dcc
4 changed files with 64 additions and 58 deletions

View File

@@ -270,8 +270,60 @@ llama_context::llama_context(
}
}
// resolve automatic Flash Attention use and reserve worst-case graph
if (!hparams.vocab_only) {
llama_memory_context_ptr mctx;
if (memory) {
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize memory module");
}
}
cross.v_embd.clear();
// resolve automatic Flash Attention use
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
if (!gf) {
throw std::runtime_error("failed to split graph for Flash Attention check");
}
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
bool fa_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
if (n->op != GGML_OP_FLASH_ATTN_EXT) {
continue;
}
ggml_backend_dev_t device_fa = ggml_backend_get_device(
ggml_backend_sched_get_tensor_backend(sched.get(), n));
// TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_fa != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
// FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
fa_device_mismatch = true;
break;
}
}
if (fa_device_mismatch) {
cparams.flash_attn = false;
LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
if (ggml_is_quantized(params.type_v)) {
throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
}
} else {
cparams.flash_attn = true;
LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
}
}
// reserve worst-case graph
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
@@ -283,17 +335,6 @@ llama_context::llama_context(
int n_splits_tg = -1;
int n_nodes_tg = -1;
llama_memory_context_ptr mctx;
if (memory) {
LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
mctx = memory->init_full();
if (!mctx) {
throw std::runtime_error("failed to initialize memory module");
}
}
cross.v_embd.clear();
// reserve pp (prompt processing) graph first so that buffers are only allocated once
{
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -301,48 +342,6 @@ llama_context::llama_context(
throw std::runtime_error("failed to allocate compute pp buffers");
}
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
ggml_backend_sched_alloc_graph(sched.get(), gf);
const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
bool fa_device_mismatch = false;
for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
ggml_tensor * n = ggml_graph_node(gf, i);
if (n->op != GGML_OP_FLASH_ATTN_EXT) {
continue;
}
ggml_backend_dev_t device_fa = ggml_backend_get_device(
ggml_backend_sched_get_tensor_backend(sched.get(), n));
// TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
const int il = std::stoi(n->name + prefix_len);
ggml_backend_dev_t device_kv = model.dev_layer(il);
if (device_fa != device_kv) {
LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
"is assigned to device %s (usually due to missing support)\n",
__func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
// FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
fa_device_mismatch = true;
break;
}
}
if (fa_device_mismatch) {
cparams.flash_attn = false;
LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
if (ggml_is_quantized(params.type_v)) {
throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
}
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
if (!gf) {
throw std::runtime_error("failed to allocate compute pp buffers");
}
} else {
cparams.flash_attn = true;
LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
}
}
n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
n_nodes_pp = ggml_graph_n_nodes(gf);
}
@@ -1366,7 +1365,7 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
return static_cast<llm_graph_result *>(gf_res_reserve.get());
}
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
if (n_tokens % n_seqs != 0) {
@@ -1401,7 +1400,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
this->n_outputs = save_n_outputs;
// initialize scheduler with the specified graph
if (!ggml_backend_sched_reserve(sched.get(), gf)) {
if (split_only) {
ggml_backend_sched_split_graph(sched.get(), gf);
} else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
return nullptr;
}

View File

@@ -196,7 +196,7 @@ public:
ggml_status graph_compute(ggml_cgraph * gf, bool batched);
// reserve a graph with a dummy ubatch of the specified size
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
private:
llm_graph_params graph_params(