llama : separate compute buffer reserve from fattn check (#15696)
Exposes ggml_backend_sched_split_graph() so that a graph can be split without allocating compute buffers, and uses it to split the graph for the automatic Flash Attention check.
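
As a rough illustration of the workflow this enables (a minimal sketch, not part of the commit: the helper name inspect_then_reserve() and the construction of sched and gf are assumed), a caller can now split a graph to inspect backend assignments before paying for compute buffers:

    #include "ggml.h"
    #include "ggml-backend.h"

    // Hypothetical helper: split first to look at backend assignments,
    // allocate compute buffers only once the graph is considered final.
    static bool inspect_then_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * gf) {
        // assigns backends and builds splits, but does not allocate compute buffers
        ggml_backend_sched_split_graph(sched, gf);

        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
            struct ggml_tensor * node = ggml_graph_node(gf, i);
            ggml_backend_t     backend = ggml_backend_sched_get_tensor_backend(sched, node);
            ggml_backend_dev_t dev     = ggml_backend_get_device(backend);
            // ... decide here whether the assignment is acceptable, e.g. inspect
            //     ggml_backend_dev_name(dev) for the nodes of interest ...
            (void) dev;
        }

        // worst-case compute buffers are only allocated here
        return ggml_backend_sched_reserve(sched, gf);
    }

This mirrors what llama_context now does for the automatic Flash Attention check below: graph_reserve(..., /*split_only=*/true) only splits the graph, the GGML_OP_FLASH_ATTN_EXT nodes are inspected, and the worst-case graph is reserved afterwards.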
@@ -307,6 +307,9 @@ extern "C" {
     GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
     GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
 
+    // Split graph without allocating it
+    GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+
     // Allocate and compute graph on the backend scheduler
     GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
     GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
@@ -902,7 +902,7 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
 }
 
 // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
-static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     // reset splits
     sched->n_splits = 0;
     sched->n_graph_inputs = 0;
@@ -1687,6 +1687,8 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
     GGML_ASSERT(sched);
     GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
 
+    ggml_backend_sched_reset(sched);
+
     ggml_backend_sched_synchronize(sched);
 
     ggml_backend_sched_split_graph(sched, measure_graph);
@@ -270,8 +270,60 @@ llama_context::llama_context(
         }
     }
 
-    // resolve automatic Flash Attention use and reserve worst-case graph
     if (!hparams.vocab_only) {
+        llama_memory_context_ptr mctx;
+        if (memory) {
+            LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
+            mctx = memory->init_full();
+            if (!mctx) {
+                throw std::runtime_error("failed to initialize memory module");
+            }
+        }
+
+        cross.v_embd.clear();
+
+        // resolve automatic Flash Attention use
+        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
+            if (!gf) {
+                throw std::runtime_error("failed to split graph for Flash Attention check");
+            }
+
+            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+            bool fa_device_mismatch = false;
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * n = ggml_graph_node(gf, i);
+                if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+                    continue;
+                }
+                ggml_backend_dev_t device_fa = ggml_backend_get_device(
+                    ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+                // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+                const int il = std::stoi(n->name + prefix_len);
+                ggml_backend_dev_t device_kv = model.dev_layer(il);
+                if (device_fa != device_kv) {
+                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+                        "is assigned to device %s (usually due to missing support)\n",
+                        __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+                    // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+                    fa_device_mismatch = true;
+                    break;
+                }
+            }
+            if (fa_device_mismatch) {
+                cparams.flash_attn = false;
+                LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+                if (ggml_is_quantized(params.type_v)) {
+                    throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
+                }
+            } else {
+                cparams.flash_attn = true;
+                LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+            }
+        }
+
+        // reserve worst-case graph
         const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
         const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
 
@@ -283,17 +335,6 @@ llama_context::llama_context(
         int n_splits_tg = -1;
         int n_nodes_tg = -1;
 
-        llama_memory_context_ptr mctx;
-        if (memory) {
-            LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
-            mctx = memory->init_full();
-            if (!mctx) {
-                throw std::runtime_error("failed to initialize memory module");
-            }
-        }
-
-        cross.v_embd.clear();
-
         // reserve pp (prompt processing) graph first so that buffers are only allocated once
         {
             auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@@ -301,48 +342,6 @@ llama_context::llama_context(
                 throw std::runtime_error("failed to allocate compute pp buffers");
             }
 
-            if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-                ggml_backend_sched_alloc_graph(sched.get(), gf);
-
-                const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
-                bool fa_device_mismatch = false;
-                for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
-                    ggml_tensor * n = ggml_graph_node(gf, i);
-                    if (n->op != GGML_OP_FLASH_ATTN_EXT) {
-                        continue;
-                    }
-                    ggml_backend_dev_t device_fa = ggml_backend_get_device(
-                        ggml_backend_sched_get_tensor_backend(sched.get(), n));
-
-                    // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
-                    GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
-                    const int il = std::stoi(n->name + prefix_len);
-                    ggml_backend_dev_t device_kv = model.dev_layer(il);
-                    if (device_fa != device_kv) {
-                        LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
-                            "is assigned to device %s (usually due to missing support)\n",
-                            __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
-                        // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
-                        fa_device_mismatch = true;
-                        break;
-                    }
-                }
-                if (fa_device_mismatch) {
-                    cparams.flash_attn = false;
-                    LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
-                    if (ggml_is_quantized(params.type_v)) {
-                        throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
-                    }
-                    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
-                    if (!gf) {
-                        throw std::runtime_error("failed to allocate compute pp buffers");
-                    }
-                } else {
-                    cparams.flash_attn = true;
-                    LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
-                }
-            }
-
             n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
             n_nodes_pp = ggml_graph_n_nodes(gf);
         }
@@ -1366,7 +1365,7 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
     return static_cast<llm_graph_result *>(gf_res_reserve.get());
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1401,7 +1400,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
-    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+    if (split_only) {
+        ggml_backend_sched_split_graph(sched.get(), gf);
+    } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
         LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
         return nullptr;
     }
@@ -196,7 +196,7 @@ public:
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
 
 private:
     llm_graph_params graph_params(