	llama : separate compute buffer reserve from fattn check (#15696)
Exposes ggml_backend_sched_split_graph() so that a graph can be split without allocating compute buffers, and uses it to split the graph for the automatic Flash Attention check.

Author: Diego Devesa
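In practice, this lets a caller dry-run the scheduler: assign backends and build the splits, then inspect the result, without paying for compute-buffer allocation. A minimal sketch of that use (assuming an initialized scheduler sched and a built, not-yet-allocated graph gf; this is not code from the commit):

#include <cstdio>
#include "ggml-backend.h"

// Sketch: dry-run the scheduler to see backend assignments without
// allocating compute buffers. Assumes sched was created with
// ggml_backend_sched_new() and gf is a built (unallocated) ggml_cgraph.
static void inspect_splits(ggml_backend_sched_t sched, struct ggml_cgraph * gf) {
    ggml_backend_sched_reset(sched);           // clear any previous assignments
    ggml_backend_sched_split_graph(sched, gf); // assign backends and build splits; no allocation

    printf("graph would run in %d split(s)\n", ggml_backend_sched_get_n_splits(sched));

    for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
        struct ggml_tensor * node = ggml_graph_node(gf, i);
        ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, node);
        printf("%-32s -> %s\n", node->name, backend ? ggml_backend_name(backend) : "(unassigned)");
    }
}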
ggml/include/ggml-backend.h

@@ -307,6 +307,9 @@ extern "C" {
     GGML_API void                 ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
     GGML_API ggml_backend_t       ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
 
+    // Split graph without allocating it
+    GGML_API void                 ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
+
     // Allocate and compute graph on the backend scheduler
     GGML_API bool                 ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
     GGML_API enum ggml_status     ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
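With this addition the scheduler exposes three graph-level calls at increasing cost. As a rough orientation (signatures copied from the header; the first is the new one):

// Rough cost ladder of the scheduler's graph-level entry points:
GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);     // new: split only, no allocation
GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);     // split + allocate compute buffers
GGML_API bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph); // split + reserve worst-case buffers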
ggml/src/ggml-backend.cpp

@@ -902,7 +902,7 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
 }
 
 // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
-static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
+void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
     // reset splits
     sched->n_splits = 0;
     sched->n_graph_inputs = 0;
@@ -1687,6 +1687,8 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
     GGML_ASSERT(sched);
     GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
 
+    ggml_backend_sched_reset(sched);
+
     ggml_backend_sched_synchronize(sched);
 
     ggml_backend_sched_split_graph(sched, measure_graph);
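The added reset matters because a caller may now run a split-only pass (as the Flash Attention check below does) before reserving; without it, assignments left over from that pass would leak into the measurement. A simplified sketch of the reserve flow after this change (not the verbatim implementation):

// Simplified sketch of the reserve flow; measurement and error handling omitted.
static bool sched_reserve_sketch(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
    ggml_backend_sched_reset(sched);                      // new: drop state from any earlier split-only pass
    ggml_backend_sched_synchronize(sched);                // wait for in-flight compute
    ggml_backend_sched_split_graph(sched, measure_graph); // assign backends, build splits
    // ... measure the splits and allocate worst-case compute buffers ...
    return true;
}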
src/llama-context.cpp

@@ -270,19 +270,7 @@ llama_context::llama_context(
         }
     }
 
-    // resolve automatic Flash Attention use and reserve worst-case graph
     if (!hparams.vocab_only) {
-        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
-
-        int n_splits_pp = -1;
-        int n_nodes_pp  = -1;
-
-        int n_splits_tg = -1;
-        int n_nodes_tg  = -1;
-
         llama_memory_context_ptr mctx;
         if (memory) {
             LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
@@ -293,16 +281,12 @@ llama_context::llama_context(
         }
 
         cross.v_embd.clear();
-        // reserve pp (prompt processing) graph first so that buffers are only allocated once
-        {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
-            if (!gf) {
-                throw std::runtime_error("failed to allocate compute pp buffers");
-            }
-
-            if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-                ggml_backend_sched_alloc_graph(sched.get(), gf);
-
+        // resolve automatic Flash Attention use
+        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
+            auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
+            if (!gf) {
+                throw std::runtime_error("failed to split graph for Flash Attention check");
+            }
 
             const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
             bool fa_device_mismatch = false;
@@ -333,16 +317,31 @@ llama_context::llama_context(
                 if (ggml_is_quantized(params.type_v)) {
                     throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
                 }
-                    auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
-                    if (!gf) {
-                        throw std::runtime_error("failed to allocate compute pp buffers");
-                    }
             } else {
                 cparams.flash_attn = true;
                 LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
             }
         }
 
+        // reserve worst-case graph
+        const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
+        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+
+        int n_splits_pp = -1;
+        int n_nodes_pp  = -1;
+
+        int n_splits_tg = -1;
+        int n_nodes_tg  = -1;
+
+        // reserve pp (prompt processing) graph first so that buffers are only allocated once
+        {
+            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+            if (!gf) {
+                throw std::runtime_error("failed to allocate compute pp buffers");
+            }
+
             n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
             n_nodes_pp  = ggml_graph_n_nodes(gf);
         }
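Net effect in the constructor: the Flash Attention probe and the worst-case reservation are now two separate phases. Condensed from the hunks above (a sketch of the control flow, not the verbatim code):

// Phase 1: split-only probe with a tiny dummy ubatch (1 token, 1 seq, 0 outputs);
// backends are assigned, but no compute buffers are allocated yet.
if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
    auto * gf = graph_reserve(1, 1, 0, mctx.get(), true);
    // ... check which backend each LLAMA_TENSOR_NAME_FATTN node landed on,
    //     then set cparams.flash_attn to the resolved value ...
}

// Phase 2: with flash_attn settled, reserve the worst-case pp graph once;
// this is the only pass that actually allocates compute buffers.
auto * gf_pp = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());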
@@ -1366,7 +1365,7 @@ llm_graph_result * llama_context::get_gf_res_reserve() const {
     return static_cast<llm_graph_result *>(gf_res_reserve.get());
 }
 
-ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) {
+ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only) {
     LLAMA_LOG_DEBUG("%s: reserving a graph for ubatch with n_tokens = %4u, n_seqs = %2u, n_outputs = %4u\n", __func__, n_tokens, n_seqs, n_outputs);
 
     if (n_tokens % n_seqs != 0) {
@@ -1401,7 +1400,9 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
     this->n_outputs = save_n_outputs;
 
     // initialize scheduler with the specified graph
-    if (!ggml_backend_sched_reserve(sched.get(), gf)) {
+    if (split_only) {
+        ggml_backend_sched_split_graph(sched.get(), gf);
+    } else if (!ggml_backend_sched_reserve(sched.get(), gf)) {
         LLAMA_LOG_ERROR("%s: failed to allocate compute buffers\n", __func__);
         return nullptr;
     }
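One caveat with the new mode: when split_only is true, the returned graph has backend assignments but its tensors are never given backing buffers, so it is only suitable for inspection, never for compute. A hedged sketch of the resulting contract (names from the diff; not documentation from the commit):

// split_only == true: inspect-only; computing this graph would touch unallocated tensors.
ggml_cgraph * gf = graph_reserve(1, 1, 0, mctx.get(), true);
if (gf) {
    const int n_splits = ggml_backend_sched_get_n_splits(sched.get()); // fine: metadata only
    LLAMA_LOG_DEBUG("%s: probe graph has %d split(s)\n", __func__, n_splits);
    // ggml_backend_sched_graph_compute(sched.get(), gf);              // NOT safe in this mode
}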
src/llama-context.h

@@ -196,7 +196,7 @@ public:
     ggml_status graph_compute(ggml_cgraph * gf, bool batched);
 
     // reserve a graph with a dummy ubatch of the specified size
-    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx);
+    ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false);
 
 private:
     llm_graph_params graph_params(