ggml : add graph tensor allocator (#2411)

* ggml : add graph tensor allocator
* ggml : don't calculate data pointer of unallocated tensors when creating a view with an offset
* ggml : refactor ggml_view_Nd into ggml_view_tensor_offset
llama.cpp: 242 changed lines
@@ -56,8 +56,14 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif

#if !defined(GGML_USE_CUBLAS) && !defined(GGML_USE_METAL)
#include "ggml-alloc.h"
#define LLAMA_USE_ALLOCATOR
#else
#define LLAMA_USE_SCRATCH
#define LLAMA_MAX_SCRATCH_BUFFERS 16
#endif


// available llama models
enum e_model {
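For orientation, these are the ggml-alloc entry points the rest of this diff calls into. A hedged sketch of their declarations, inferred from the call sites in the hunks below; the authoritative signatures are in ggml-alloc.h:

    #include <stdbool.h>
    #include <stddef.h>

    struct ggml_tensor;
    struct ggml_cgraph;
    struct ggml_allocr;   // opaque graph tensor allocator

    // allocator backed by a caller-provided buffer [data, data + size)
    struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
    // "measure" allocator: tracks the peak memory a graph would need instead of handing out real memory
    struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
    void                 ggml_allocr_free(struct ggml_allocr * alloc);
    bool                 ggml_allocr_is_measure(struct ggml_allocr * alloc);
    // forget previous placements before building the next graph
    void                 ggml_allocr_reset(struct ggml_allocr * alloc);
    // place a single tensor (used for graph inputs such as inp_tokens)
    void                 ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor);
    // place every tensor in the graph; returns the buffer size that was required
    size_t               ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);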
@@ -327,13 +333,22 @@ struct llama_model {

struct llama_context {
    llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {}
#ifdef GGML_USE_METAL
    ~llama_context() {
        if (model_owner) {
            delete &model;
        }
#ifdef GGML_USE_METAL
        if (ctx_metal) {
            ggml_metal_free(ctx_metal);
        }
    }
#endif
#ifdef LLAMA_USE_ALLOCATOR
        if (alloc) {
            ggml_allocr_free(alloc);
        }
#endif
    }

    std::mt19937 rng;

    bool has_evaluated_once = false;
@@ -371,7 +386,17 @@ struct llama_context {
    // memory buffers used to evaluate the model
    // TODO: move in llama_state
    llama_ctx_buffer buf_compute;

#ifdef LLAMA_USE_ALLOCATOR
    llama_ctx_buffer buf_alloc;
    ggml_allocr * alloc = NULL;
#endif

#ifdef LLAMA_USE_SCRATCH
    llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
    int    buf_last = 0;
    size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
#endif

#ifdef GGML_USE_METAL
    ggml_metal_context * ctx_metal = NULL;
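In an allocator build the eval memory is split in two: buf_compute now only holds the ggml_tensor and ggml_cgraph structs (the graph is created with no_alloc = true further down), while buf_alloc holds the tensor data and is managed by alloc. A compact sketch of how the new members are sized, condensed from the llama_new_context_with_model hunk near the end of this diff; the helper name is illustrative and not part of the commit:

    // illustrative helper: how the two buffers of an allocator build are set up
    static void llama_setup_alloc_buffers(llama_context & lctx, size_t measured_data_size) {
        const size_t tensor_alignment = 32;
        // metadata only: tensor/graph structs, no tensor data
        lctx.buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());
        // tensor data, placed by the graph allocator on every eval
        lctx.buf_alloc.resize(measured_data_size + tensor_alignment);
        lctx.alloc = ggml_allocr_new(lctx.buf_alloc.addr, lctx.buf_alloc.size, tensor_alignment);
    }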
@@ -381,9 +406,6 @@ struct llama_context {
    ggml_mpi_context * ctx_mpi = NULL;
#endif

    int    buf_last = 0;
    size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };

    void use_buf(struct ggml_context * ctx, int i) {
#if defined(LLAMA_USE_SCRATCH)
        size_t last_size = 0;
@@ -1230,12 +1252,16 @@ static void llama_model_load_internal(
        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;

        // this is the total memory required to run the inference
        const size_t mem_required =
        size_t mem_required =
            ctx_size +
            mmapped_size - vram_weights + // weights in VRAM not in memory
            mmapped_size - vram_weights; // weights in VRAM not in memory

#ifndef LLAMA_USE_ALLOCATOR
        mem_required +=
            MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
            MEM_REQ_SCRATCH1().at(model.type) +
            MEM_REQ_EVAL().at(model.type);
#endif

        // this is the memory required by one llama_state
        const size_t mem_required_state =
@@ -1360,32 +1386,15 @@ static bool llama_model_load(
    }
}

// evaluate the transformer
//
//   - lctx:      llama context
//   - tokens:    new batch of tokens to process
//   - embd       embeddings input
//   - n_tokens   number of tokens
//   - n_past:    the context size so far
//   - n_threads: number of threads to use
//
static bool llama_eval_internal(
static struct ggml_cgraph * llama_build_graph(
         llama_context & lctx,
     const llama_token * tokens,
           const float * embd,
                   int   n_tokens,
                   int   n_past,
                   int   n_threads,
            const char * cgraph_fname) {
                   int   n_past) {

    LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));

#ifdef GGML_USE_MPI
    ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
#endif

    const int64_t t_start_us = ggml_time_us();

    const int N = n_tokens;

    const auto & model   = lctx.model;
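The old llama_eval_internal is split in two: graph construction moves into llama_build_graph, which no longer needs n_threads or cgraph_fname. It is now called from two places, both shown in full later in this diff; a condensed sketch of the two call shapes:

    // per-eval: build the real graph for the current batch (llama_eval_internal)
    ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);

    // at context creation: build a worst-case graph (max batch, max n_past) purely to
    // measure memory; the token value itself is irrelevant (llama_new_context_with_model)
    int n_tokens_max = std::min((int) hparams.n_ctx, params.n_batch);
    int n_past_max   = hparams.n_ctx - n_tokens_max;
    llama_token token = llama_token_bos();
    ggml_cgraph * gf_measure = llama_build_graph(*ctx, &token, NULL, n_tokens_max, n_past_max);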
@@ -1401,10 +1410,8 @@ static bool llama_eval_internal(
    const int64_t n_head      = hparams.n_head;
    const int64_t n_head_kv   = hparams.n_head_kv;
    const int64_t n_embd_head = hparams.n_embd_head();
    const int64_t n_vocab     = hparams.n_vocab;
    const int64_t n_embd_gqa  = hparams.n_embd_gqa();


    LLAMA_ASSERT(n_embd_head == hparams.n_rot);

    const float freq_base  = hparams.rope_freq_base;
@@ -1416,26 +1423,35 @@ static bool llama_eval_internal(
    auto & mem_per_token = lctx.mem_per_token;
    auto & buf_compute   = lctx.buf_compute;


    struct ggml_init_params params = {
        /*.mem_size   =*/ buf_compute.size,
        /*.mem_buffer =*/ buf_compute.addr,
        /*.no_alloc   =*/ false,
    };

#ifdef LLAMA_USE_ALLOCATOR
    params.no_alloc = true;
#endif

    struct ggml_context * ctx0 = ggml_init(params);

    ggml_cgraph * gf = ggml_new_graph(ctx0);

    // for big prompts, if BLAS is enabled, it is better to use only one thread
    // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
    n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;

    struct ggml_tensor * cur;
    struct ggml_tensor * inpL;

    if (tokens) {
        struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);

#ifdef LLAMA_USE_ALLOCATOR
        ggml_allocr_alloc(lctx.alloc, inp_tokens);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
        }
#else
        memcpy(inp_tokens->data, tokens, N*ggml_element_size(inp_tokens));
#endif
        ggml_set_name(inp_tokens, "inp_tokens");

        inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens);
@@ -1445,7 +1461,15 @@ static bool llama_eval_internal(
#endif

        inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N);

#ifdef LLAMA_USE_ALLOCATOR
        ggml_allocr_alloc(lctx.alloc, inpL);
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
        }
#else
        memcpy(inpL->data, embd, N * n_embd * ggml_element_size(inpL));
#endif
    }

    const int i_gpu_start = n_layer - n_gpu_layers;
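Each graph input now follows the same three-step pattern: create the tensor, ask the allocator for its storage, and only copy data in when this is not the measure pass (during measurement the tensor has no real backing memory). A minimal sketch of that pattern as a hypothetical helper; the commit itself repeats the pattern inline, as above and again for KQ_scale in the next hunk:

    // hypothetical helper, not part of the commit: the recurring input-tensor pattern
    static void llama_set_graph_input(llama_context & lctx, struct ggml_tensor * t,
                                      const void * src, size_t nbytes) {
    #ifdef LLAMA_USE_ALLOCATOR
        ggml_allocr_alloc(lctx.alloc, t);             // reserve (or just measure) storage for the input
        if (!ggml_allocr_is_measure(lctx.alloc)) {
            memcpy(t->data, src, nbytes);             // real memory exists only outside the measure pass
        }
    #else
        memcpy(t->data, src, nbytes);                 // legacy path: ctx0 allocated the data eagerly
    #endif
    }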
@@ -1472,6 +1496,17 @@ static bool llama_eval_internal(
    }
#endif // GGML_USE_CUBLAS

    struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
#ifdef LLAMA_USE_ALLOCATOR
    ggml_allocr_alloc(lctx.alloc, KQ_scale);
    if (!ggml_allocr_is_measure(lctx.alloc)) {
        ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
    }
#else
    ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head));
#endif
    ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");

    for (int il = 0; il < n_layer; ++il) {
        ggml_format_name(inpL, "layer_inp_%d", il);

@@ -1567,9 +1602,6 @@ static bool llama_eval_internal(
            ggml_set_name(KQ, "KQ");

            // KQ_scaled = KQ / sqrt(n_embd_head)
            struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
            ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)");

            // KQ_scaled shape [n_past + N, N, n_head, 1]
            struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
            offload_func_kq(KQ_scaled);
@@ -1685,9 +1717,6 @@ static bool llama_eval_internal(

    lctx.use_buf(ctx0, 0);

    // used at the end to optionally extract the embeddings
    struct ggml_tensor * embeddings = NULL;

    // norm
    {
        cur = ggml_rms_norm(ctx0, inpL, rms_norm_eps);
@@ -1698,8 +1727,6 @@ static bool llama_eval_internal(
        cur = ggml_mul(ctx0, cur, model.norm);
        // offload_func_nr(cur); // TODO CPU + GPU mirrored backend
        ggml_set_name(cur, "result_norm");

        embeddings = cur;
    }

    // lm_head
@@ -1711,12 +1738,82 @@ static bool llama_eval_internal(
    // logits -> probs
    //cur = ggml_soft_max_inplace(ctx0, cur);

    // run the computation
    ggml_build_forward_expand(gf, cur);

    // fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs);
    if (mem_per_token == 0) {
        mem_per_token = ggml_used_mem(ctx0)/N;
    }

#if 0
    printf("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
            ggml_used_mem(ctx0)/1024.0/1024.0,
            lctx.get_buf_max_mem(0)/1024.0/1024.0,
            lctx.get_buf_max_mem(1)/1024.0/1024.0,
            lctx.work_buffer.size()/1024.0/1024.0,
            n_past, N);
#endif

    ggml_free(ctx0);

    return gf;
}

// evaluate the transformer
//
//   - lctx:      llama context
//   - tokens:    new batch of tokens to process
//   - embd       embeddings input
//   - n_tokens   number of tokens
//   - n_past:    the context size so far
//   - n_threads: number of threads to use
//
static bool llama_eval_internal(
         llama_context & lctx,
     const llama_token * tokens,
           const float * embd,
                   int   n_tokens,
                   int   n_past,
                   int   n_threads,
            const char * cgraph_fname) {

    LLAMA_ASSERT((!tokens && embd) || (tokens && !embd));

    const int64_t t_start_us = ggml_time_us();

#ifdef GGML_USE_MPI
    ggml_mpi_eval_init(lctx.ctx_mpi, &n_tokens, &n_past, &n_threads);
#endif

    const int N = n_tokens;

    const auto & model   = lctx.model;
    const auto & hparams = model.hparams;

    const auto & kv_self = lctx.kv_self;

    LLAMA_ASSERT(!!kv_self.ctx);

    const int64_t n_embd      = hparams.n_embd;
    const int64_t n_vocab     = hparams.n_vocab;

#ifdef LLAMA_USE_ALLOCATOR
    ggml_allocr_reset(lctx.alloc);
#endif

    ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);

#ifdef LLAMA_USE_ALLOCATOR
    ggml_allocr_alloc_graph(lctx.alloc, gf);
#endif

    // fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs);

    // for big prompts, if BLAS is enabled, it is better to use only one thread
    // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance
    n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads;

#if GGML_USE_MPI
    const int64_t n_layer = hparams.n_layer;
    ggml_mpi_graph_compute_pre(lctx.ctx_mpi, gf, n_layer);
#endif

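Taken together, the allocator path per evaluation is: reset the allocator, rebuild the graph for this batch, then let ggml_allocr_alloc_graph place every intermediate tensor inside buf_alloc before the graph is computed. A condensed sketch of the sequence implemented above; the actual compute call is unchanged and not part of the hunks shown here:

    #ifdef LLAMA_USE_ALLOCATOR
        ggml_allocr_reset(lctx.alloc);              // drop placements from the previous eval
    #endif

        ggml_cgraph * gf = llama_build_graph(lctx, tokens, embd, n_tokens, n_past);

    #ifdef LLAMA_USE_ALLOCATOR
        ggml_allocr_alloc_graph(lctx.alloc, gf);    // assign addresses in buf_alloc to all graph tensors
    #endif

        // ... the graph is then computed exactly as before ...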
@@ -1760,6 +1857,10 @@ static bool llama_eval_internal(
    lctx.kv_self.n = n_past + N;

    struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1];
    struct ggml_tensor * embeddings = gf->nodes[gf->n_nodes - 2];

    LLAMA_ASSERT(strcmp(res->name, "result_output") == 0);
    LLAMA_ASSERT(strcmp(embeddings->name, "result_norm") == 0);

    if (cgraph_fname) {
        ggml_graph_export(gf, cgraph_fname);
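With the build/eval split, locals such as embeddings no longer survive into llama_eval_internal, so the outputs are recovered from the tail of the graph and their names are asserted. A hypothetical helper expressing the same lookup-and-check:

    // hypothetical helper, equivalent to the two lookups above
    static struct ggml_tensor * llama_graph_output(struct ggml_cgraph * gf, int i_from_end, const char * name) {
        struct ggml_tensor * t = gf->nodes[gf->n_nodes - 1 - i_from_end];
        LLAMA_ASSERT(strcmp(t->name, name) == 0);   // guards against reordering of the final graph nodes
        return t;
    }

    // struct ggml_tensor * res        = llama_graph_output(gf, 0, "result_output");
    // struct ggml_tensor * embeddings = llama_graph_output(gf, 1, "result_norm");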
@@ -1798,21 +1899,6 @@ static bool llama_eval_internal(
        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
    }

    if (mem_per_token == 0) {
        mem_per_token = ggml_used_mem(ctx0)/N;
    }

#if 0
    printf("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
            ggml_used_mem(ctx0)/1024.0/1024.0,
            lctx.get_buf_max_mem(0)/1024.0/1024.0,
            lctx.get_buf_max_mem(1)/1024.0/1024.0,
            lctx.work_buffer.size()/1024.0/1024.0,
            n_past, N);
#endif

    ggml_free(ctx0);

    // measure the performance only for the single-token evals
    if (N == 1) {
        lctx.t_eval_us += ggml_time_us() - t_start_us;
@@ -3180,10 +3266,47 @@ struct llama_context * llama_new_context_with_model(
            ctx->embedding.resize(hparams.n_embd);
        }

        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
#ifdef LLAMA_USE_ALLOCATOR
        {
            static const size_t tensor_alignment = 32;
            // the compute buffer is used to store the tensor and graph structs, while the allocator buffer is used for the tensor data
            ctx->buf_compute.resize(ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead());

            // create measure allocator
            ctx->alloc = ggml_allocr_new_measure(tensor_alignment);

            // build worst-case graph
            int n_tokens = std::min((int)hparams.n_ctx, params.n_batch);
            int n_past = hparams.n_ctx - n_tokens;
            llama_token token = llama_token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph
            ggml_cgraph * gf = llama_build_graph(*ctx, &token, NULL, n_tokens, n_past);

            // measure memory requirements for the graph
            size_t alloc_size = ggml_allocr_alloc_graph(ctx->alloc, gf) + tensor_alignment;

            fprintf(stderr, "%s: compute buffer total size = %7.2f MB\n", __func__, (ctx->buf_compute.size + alloc_size) / 1024.0 / 1024.0);

            // debug - for comparison with scratch buffer
            //size_t prev_req =
            //    MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type) +
            //    MEM_REQ_SCRATCH1().at(ctx->model.type) +
            //    MEM_REQ_EVAL().at(ctx->model.type);
            //fprintf(stderr, "%s: (debug) equivalent with scratch buffer = %7.2f MB\n", __func__, prev_req / 1024.0 / 1024.0);

            // recreate allocator with exact memory requirements
            ggml_allocr_free(ctx->alloc);

            ctx->buf_alloc.resize(alloc_size);
            ctx->alloc = ggml_allocr_new(ctx->buf_alloc.addr, ctx->buf_alloc.size, tensor_alignment);
        }
#else
        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead());
#endif

#ifdef LLAMA_USE_SCRATCH
        ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
        ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
#endif
    }

#ifdef GGML_USE_METAL
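The measure-then-allocate recipe used above is not specific to the llama eval graph. The sketch below applies it to a toy ggml graph end to end, assuming the ggml / ggml-alloc API exactly as it is used in this diff; the graph, buffer sizes and helper names are illustrative only, and the compute call is left out because llama.cpp drives it through its own helper with a separate work buffer:

    #include <stdio.h>
    #include <stdlib.h>
    #include "ggml.h"
    #include "ggml-alloc.h"

    // toy graph: with no_alloc = true, ctx only stores tensor/graph structs
    static struct ggml_cgraph * build_toy_graph(struct ggml_context * ctx, struct ggml_allocr * alloc) {
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024);
        ggml_allocr_alloc(alloc, a);                   // graph input: placed explicitly
        if (!ggml_allocr_is_measure(alloc)) {
            ggml_set_f32(a, 1.0f);                     // write data only when real memory exists
        }
        struct ggml_tensor * b = ggml_sqr(ctx, a);     // intermediates: placed by ggml_allocr_alloc_graph
        struct ggml_tensor * c = ggml_sum(ctx, b);
        struct ggml_cgraph * gf = ggml_new_graph(ctx);
        ggml_build_forward_expand(gf, c);
        return gf;
    }

    int main(void) {
        const size_t tensor_alignment = 32;
        struct ggml_init_params params = {
            /*.mem_size   =*/ ggml_tensor_overhead()*GGML_MAX_NODES + ggml_graph_overhead(),
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ true,                    // tensor data comes from the allocator, not ctx
        };

        // pass 1: measure the peak tensor memory the graph needs
        struct ggml_context * ctx   = ggml_init(params);
        struct ggml_allocr  * alloc = ggml_allocr_new_measure(tensor_alignment);
        size_t mem_size = ggml_allocr_alloc_graph(alloc, build_toy_graph(ctx, alloc)) + tensor_alignment;
        ggml_allocr_free(alloc);
        ggml_free(ctx);

        // pass 2: allocate exactly that much and build the graph for real
        void * buf = malloc(mem_size);
        alloc = ggml_allocr_new(buf, mem_size, tensor_alignment);
        ctx   = ggml_init(params);
        struct ggml_cgraph * gf = build_toy_graph(ctx, alloc);
        ggml_allocr_alloc_graph(alloc, gf);
        // ... compute gf here; llama.cpp uses its own graph-compute helper with a work buffer ...

        printf("toy graph needs %zu bytes of tensor data\n", mem_size);

        ggml_allocr_free(alloc);
        ggml_free(ctx);
        free(buf);
        return 0;
    }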
@@ -3253,9 +3376,6 @@ struct llama_context * llama_init_from_file(
}

void llama_free(struct llama_context * ctx) {
    if (ctx->model_owner) {
        delete &ctx->model;
    }
    delete ctx;
}

Author: slaren