	llama : optimize memory buffers (#2325)
@@ -578,18 +578,18 @@ std::vector<llama_token> llama_tokenize(struct llama_context * ctx, const std::s
 struct llama_context_params llama_context_params_from_gpt_params(const gpt_params & params) {
     auto lparams = llama_context_default_params();
 
-    lparams.n_ctx        = params.n_ctx;
-    lparams.n_batch      = params.n_batch;
-    lparams.n_gpu_layers = params.n_gpu_layers;
-    lparams.main_gpu     = params.main_gpu;
-    lparams.tensor_split = params.tensor_split;
-    lparams.low_vram     = params.low_vram;
-    lparams.seed         = params.seed;
-    lparams.f16_kv       = params.memory_f16;
-    lparams.use_mmap     = params.use_mmap;
-    lparams.use_mlock    = params.use_mlock;
-    lparams.logits_all   = params.perplexity;
-    lparams.embedding    = params.embedding;
+    lparams.n_ctx           = params.n_ctx;
+    lparams.n_batch         = params.n_batch;
+    lparams.n_gpu_layers    = params.n_gpu_layers;
+    lparams.main_gpu        = params.main_gpu;
+    lparams.tensor_split    = params.tensor_split;
+    lparams.low_vram        = params.low_vram;
+    lparams.seed            = params.seed;
+    lparams.f16_kv          = params.memory_f16;
+    lparams.use_mmap        = params.use_mmap;
+    lparams.use_mlock       = params.use_mlock;
+    lparams.logits_all      = params.perplexity;
+    lparams.embedding       = params.embedding;
     lparams.rope_freq_base  = params.rope_freq_base;
     lparams.rope_freq_scale = params.rope_freq_scale;
 
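For orientation, here is a minimal sketch of how this helper is typically consumed by the examples. Only llama_context_params_from_gpt_params and llama_new_context_with_model appear in this diff; the model path and the use of llama_load_model_from_file are my own assumptions for illustration.

    // Sketch only (not from this commit): translate example-level gpt_params into
    // llama_context_params before creating a context. "model.bin" and
    // llama_load_model_from_file are assumed here for illustration.
    #include "common.h"
    #include "llama.h"

    int run_example() {
        gpt_params params;
        params.n_ctx   = 2048;   // requested context size
        params.n_batch = 512;    // prompt-processing batch size (the scratch tables below assume 512)

        llama_context_params lparams = llama_context_params_from_gpt_params(params);

        llama_model   * model = llama_load_model_from_file("model.bin", lparams); // assumed API
        llama_context * ctx   = llama_new_context_with_model(model, lparams);
        return ctx != nullptr ? 0 : 1;
    }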
@@ -139,17 +139,14 @@ int main(int argc, char ** argv) {
                 params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    // determine the maximum memory usage needed to do inference for the given n_batch and n_predict parameters
+    // determine the maximum memory usage needed to do inference for the given n_batch and n_ctx parameters
     // uncomment the "used_mem" line in llama.cpp to see the results
     if (params.mem_test) {
         {
-            const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
-            llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
-        }
+            fprintf(stderr, "%s: testing memory usage for n_batch = %d, n_ctx = %d\n", __func__, params.n_batch, params.n_ctx);
 
-        {
-            const std::vector<llama_token> tmp = { 0, };
-            llama_eval(ctx, tmp.data(), tmp.size(), params.n_predict - 1, params.n_threads);
+            const std::vector<llama_token> tmp(params.n_batch, llama_token_bos());
+            llama_eval(ctx, tmp.data(), tmp.size(), params.n_ctx, params.n_threads);
         }
 
         llama_print_timings(ctx);
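The memory test now evaluates one full batch of n_batch BOS tokens at n_past = n_ctx instead of a lone token at n_predict - 1: with the KV cache treated as already full, that batch produces the largest attention tensors the context will ever see, so the measured scratch usage is the worst case for the chosen n_batch and n_ctx. A condensed sketch of the idea (my own wrapper, not code from this commit):

    // Sketch: probe worst-case compute memory by pretending the KV cache is
    // already full (n_past = n_ctx) and evaluating a full batch of tokens.
    #include <vector>
    #include "llama.h"

    static void probe_peak_memory(llama_context * ctx, int n_batch, int n_ctx, int n_threads) {
        // a batch of BOS tokens is enough; only the shapes matter for memory use
        std::vector<llama_token> tmp(n_batch, llama_token_bos());
        llama_eval(ctx, tmp.data(), (int) tmp.size(), /*n_past=*/n_ctx, n_threads);
    }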
							
								
								
									
llama.cpp
@@ -98,18 +98,17 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
 }
 
 //
-// memory sizes
+// memory sizes (calculated for n_batch == 512)
 //
 
 static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx)
 {
     static std::map<e_model, size_t> k_sizes = {
-        /* empirical scaling, still a guess */
-        { MODEL_3B,   ((size_t) n_ctx / 16ull + 128ull) * MB },
-        { MODEL_7B,   ((size_t) n_ctx / 16ull + 256ull) * MB },
-        { MODEL_13B,  ((size_t) n_ctx / 12ull + 256ull) * MB },
-        { MODEL_30B,  ((size_t) n_ctx / 10ull + 256ull) * MB },
-        { MODEL_65B,  ((size_t) n_ctx /  8ull + 512ull) * MB },
+        { MODEL_3B,   ((size_t) n_ctx / 16ull +  92ull) * MB },
+        { MODEL_7B,   ((size_t) n_ctx / 16ull + 100ull) * MB },
+        { MODEL_13B,  ((size_t) n_ctx / 12ull + 120ull) * MB },
+        { MODEL_30B,  ((size_t) n_ctx /  9ull + 160ull) * MB },
+        { MODEL_65B,  ((size_t) n_ctx /  6ull + 256ull) * MB }, // guess
     };
     return k_sizes;
 }
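As a quick sanity check of the new table (MB is 1024*1024, as in llama.cpp; n_ctx = 2048 is just an example value), the 7B scratch0 buffer shrinks from 384 MB to 228 MB:

    // Standalone check of the MODEL_7B scratch0 entry above (old vs new table).
    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t MB     = 1024ull * 1024ull;
        const size_t n_ctx  = 2048;                            // example context size
        const size_t new_7b = (n_ctx / 16ull + 100ull) * MB;   // new table entry
        const size_t old_7b = (n_ctx / 16ull + 256ull) * MB;   // previous entry
        printf("7B scratch0 @ n_ctx=2048: %zu MB (was %zu MB)\n", new_7b / MB, old_7b / MB);
        return 0;
    }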
@@ -117,38 +116,24 @@ static const std::map<e_model, size_t> & MEM_REQ_SCRATCH0(int n_ctx)
 static const std::map<e_model, size_t> & MEM_REQ_SCRATCH1()
 {
     static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,    256ull * MB },
-        { MODEL_7B,    512ull * MB },
-        { MODEL_13B,   512ull * MB },
-        { MODEL_30B,   512ull * MB },
-        { MODEL_65B,  1024ull * MB },
+        { MODEL_3B,  128ull * MB },
+        { MODEL_7B,  160ull * MB },
+        { MODEL_13B, 192ull * MB },
+        { MODEL_30B, 256ull * MB },
+        { MODEL_65B, 384ull * MB }, // guess
     };
     return k_sizes;
 }
 
-// 2*n_embd*n_ctx*n_layer*sizeof(float16)
-static const std::map<e_model, size_t> & MEM_REQ_KV_SELF()
+// used to store the compute graph tensors + non-scratch data
+static const std::map<e_model, size_t> & MEM_REQ_EVAL()
 {
     static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,    682ull * MB },
-        { MODEL_7B,   1026ull * MB },
-        { MODEL_13B,  1608ull * MB },
-        { MODEL_30B,  3124ull * MB },
-        { MODEL_65B,  5120ull * MB },
-    };
-    return k_sizes;
-}
-
-// this is mostly needed for temporary mul_mat buffers to dequantize the data
-// not actually needed if BLAS is disabled
-static const std::map<e_model, size_t> & MEM_REQ_EVAL(int n_ctx)
-{
-    static std::map<e_model, size_t> k_sizes = {
-        { MODEL_3B,  ((size_t) n_ctx / 256ull +  512ull) * MB },
-        { MODEL_7B,  ((size_t) n_ctx / 256ull +  768ull) * MB },
-        { MODEL_13B, ((size_t) n_ctx / 256ull + 1024ull) * MB },
-        { MODEL_30B, ((size_t) n_ctx / 256ull + 1280ull) * MB },
-        { MODEL_65B, ((size_t) n_ctx / 256ull + 1536ull) * MB },
+        { MODEL_3B,   8ull * MB },
+        { MODEL_7B,  10ull * MB },
+        { MODEL_13B, 12ull * MB },
+        { MODEL_30B, 16ull * MB },
+        { MODEL_65B, 24ull * MB }, // guess
     };
     return k_sizes;
 }
@@ -199,6 +184,15 @@ struct llama_hparams {
     bool operator!=(const llama_hparams & other) const {
         return static_cast<bool>(memcmp(this, &other, sizeof(llama_hparams)));
     }
+
+    size_t kv_size() const {
+        size_t result = 2ull;
+        result *= (size_t) n_embd;
+        result *= (size_t) n_ctx;
+        result *= (size_t) n_layer;
+        result *= sizeof(ggml_fp16_t);
+        return result;
+    }
 };
 
 struct llama_layer {
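The new hparams.kv_size() helper computes exactly what the old MEM_REQ_KV_SELF comment described, 2*n_embd*n_ctx*n_layer*sizeof(float16), but from the actual hyperparameters instead of a fixed table. A worked example for typical LLaMA-7B shapes (n_embd = 4096 and n_layer = 32 are assumptions here; n_ctx = 2048 is chosen to match the old table):

    // Worked example of hparams.kv_size() for typical 7B shapes.
    // sizeof(ggml_fp16_t) is 2 bytes; uint16_t stands in for it here.
    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const size_t n_embd  = 4096;   // LLaMA-7B embedding size
        const size_t n_layer = 32;     // LLaMA-7B layer count
        const size_t n_ctx   = 2048;   // example context size
        // 2 caches (K and V) * n_embd * n_ctx * n_layer * 2 bytes per fp16 value
        const size_t kv = 2ull * n_embd * n_ctx * n_layer * sizeof(uint16_t);
        // prints 1024 MB -- in line with the 1026 MB the old MEM_REQ_KV_SELF table listed for 7B
        printf("kv cache: %.0f MB\n", kv / 1024.0 / 1024.0);
        return 0;
    }

Because the size now follows the real n_ctx, the same 7B model at n_ctx = 512 accounts for only 256 MB per state rather than the fixed 1026 MB estimate.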
@@ -1069,7 +1063,7 @@ static void llama_model_load_internal(
     {
         model.buf.resize(ctx_size);
         if (use_mlock) {
-            model.mlock_buf.init(model.buf.addr);
+            model.mlock_buf.init   (model.buf.addr);
             model.mlock_buf.grow_to(model.buf.size);
         }
 
@@ -1186,11 +1180,11 @@ static void llama_model_load_internal(
             mmapped_size - vram_weights + // weights in VRAM not in memory
             MEM_REQ_SCRATCH0(hparams.n_ctx).at(model.type) +
             MEM_REQ_SCRATCH1().at(model.type) +
-            MEM_REQ_EVAL(hparams.n_ctx).at(model.type);
+            MEM_REQ_EVAL().at(model.type);
 
         // this is the memory required by one llama_state
         const size_t mem_required_state =
-            scale*MEM_REQ_KV_SELF().at(model.type);
+            scale*hparams.kv_size();
 
         fprintf(stderr, "%s: mem required  = %7.2f MB (+ %7.2f MB per state)\n", __func__,
                 mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
@@ -1231,7 +1225,7 @@ static void llama_model_load_internal(
                 fprintf(stderr, "%s: cannot offload v cache to GPU due to low VRAM option\n", __func__);
             } else {
                 fprintf(stderr, "%s: offloading v cache to GPU\n", __func__);
-                vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2;
+                vram_kv_cache += hparams.kv_size() / 2;
             }
         }
         if (n_gpu_layers > (int) hparams.n_layer + 2) {
@@ -1239,7 +1233,7 @@ static void llama_model_load_internal(
                 fprintf(stderr, "%s: cannot offload k cache to GPU due to low VRAM option\n", __func__);
             } else {
                 fprintf(stderr, "%s: offloading k cache to GPU\n", __func__);
-                vram_kv_cache += MEM_REQ_KV_SELF().at(model.type) / 2;
+                vram_kv_cache += hparams.kv_size() / 2;
             }
         }
 #elif defined(GGML_USE_CLBLAST)
@@ -1739,10 +1733,12 @@ static bool llama_eval_internal(
     }
 
 #if 0
-    printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB %.3f MB\n", __func__,
+    printf("\n%s: used_mem: eval ctx %.3f MB, scratch %.3f MB %.3f MB, work buf %.3f MB, n_past = %d, N = %d\n", __func__,
             ggml_used_mem(ctx0)/1024.0/1024.0,
             lctx.get_buf_max_mem(0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(1)/1024.0/1024.0);
+            lctx.get_buf_max_mem(1)/1024.0/1024.0,
+            lctx.work_buffer.size()/1024.0/1024.0,
+            n_past, N);
 #endif
 
     ggml_free(ctx0);
@@ -2448,8 +2444,8 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_Q5_0: quantized_type = GGML_TYPE_Q5_0; break;
         case LLAMA_FTYPE_MOSTLY_Q5_1: quantized_type = GGML_TYPE_Q5_1; break;
         case LLAMA_FTYPE_MOSTLY_Q8_0: quantized_type = GGML_TYPE_Q8_0; break;
-        case LLAMA_FTYPE_MOSTLY_F16: quantized_type = GGML_TYPE_F16; break;
-        case LLAMA_FTYPE_ALL_F32: quantized_type = GGML_TYPE_F32; break;
+        case LLAMA_FTYPE_MOSTLY_F16:  quantized_type = GGML_TYPE_F16;  break;
+        case LLAMA_FTYPE_ALL_F32:     quantized_type = GGML_TYPE_F32;  break;
 
 #ifdef GGML_USE_K_QUANTS
         // K-quants
@@ -2533,16 +2529,6 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         } else {
             new_type = quantized_type;
 #ifdef GGML_USE_K_QUANTS
-            bool convert_incompatible_tensor = false;
-            if (quantized_type == GGML_TYPE_Q2_K || quantized_type == GGML_TYPE_Q3_K || quantized_type == GGML_TYPE_Q4_K ||
-                quantized_type == GGML_TYPE_Q5_K || quantized_type == GGML_TYPE_Q6_K) {
-                int nx = tensor.ne.at(0);
-                int ny = tensor.ne.at(1);
-                if (nx % QK_K != 0 || ny % QK_K != 0) {
-                    fprintf(stderr, "\n\nTensor sizes %d x %d are not divisible by %d, required for k-quants.\n",nx,ny,QK_K);
-                    convert_incompatible_tensor = true;
-                }
-            }
             if (tensor.name == "output.weight") {
                 int nx = tensor.ne.at(0);
                 int ny = tensor.ne.at(1);
@@ -2568,6 +2554,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 if      (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M || ftype == LLAMA_FTYPE_MOSTLY_Q2_K) new_type = GGML_TYPE_Q4_K;
                 else if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_L) new_type = GGML_TYPE_Q5_K;
             }
+            bool convert_incompatible_tensor = false;
+            if (new_type == GGML_TYPE_Q2_K || new_type == GGML_TYPE_Q3_K || new_type == GGML_TYPE_Q4_K ||
+                new_type == GGML_TYPE_Q5_K || new_type == GGML_TYPE_Q6_K) {
+                int nx = tensor.ne.at(0);
+                int ny = tensor.ne.at(1);
+                if (nx % QK_K != 0 || ny % QK_K != 0) {
+                    fprintf(stderr, "\n\nTensor sizes %d x %d are not divisible by %d, required for k-quants.\n",nx,ny,QK_K);
+                    convert_incompatible_tensor = true;
+                }
+            }
             if (convert_incompatible_tensor) {
                 if (tensor.name == "output.weight") {
                     new_type = GGML_TYPE_F16; //fall back to F16 instead of just failing.
@@ -2594,7 +2590,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
                 f32_data = (float *) f32_conv_buf.addr;
             }
 
-            printf("quantizing .. ");
+            printf("quantizing to %s .. ", ggml_type_name(new_type));
             fflush(stdout);
 
             work.resize(nelements * 4); // upper bound on size
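The k-quant compatibility check is moved below the per-tensor overrides so that it inspects new_type, the type that will actually be used, rather than the global quantized_type. The rule itself is just a divisibility test against the k-quant super-block size; in the sketch below QK_K = 256 (the default build value) and the 3200-wide tensor standing in for a 3B model's n_embd are both my own illustrative assumptions:

    // Sketch of the k-quant compatibility rule used above: both tensor
    // dimensions must be divisible by the k-quant super-block size QK_K.
    #include <cstdio>

    static bool is_k_quant_compatible(int nx, int ny, int qk_k = 256 /* assumed default QK_K */) {
        return (nx % qk_k == 0) && (ny % qk_k == 0);
    }

    int main() {
        printf("4096 x 4096  -> %s\n", is_k_quant_compatible(4096, 4096)  ? "ok" : "fall back");
        printf("3200 x 32000 -> %s\n", is_k_quant_compatible(3200, 32000) ? "ok" : "fall back");
        return 0;
    }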
@@ -2775,7 +2771,7 @@ struct llama_context * llama_new_context_with_model(
             ctx->embedding.resize(hparams.n_embd);
         }
 
-        ctx->buf_compute.resize(MEM_REQ_EVAL(hparams.n_ctx).at(ctx->model.type));
+        ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type));
 
         ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type));
         ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));
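Putting the new numbers together, here is a rough host-side estimate for a 7B model at n_ctx = 512, following the mem_required / mem_required_state breakdown earlier in this diff (weights and mmap size omitted; the 7B shapes n_embd = 4096 and n_layer = 32 are assumptions of this sketch):

    // Rough recreation of the mem_required / mem_required_state breakdown above
    // for MODEL_7B at n_ctx = 512 (weights and mmap size left out).
    #include <cstddef>
    #include <cstdio>

    int main() {
        const size_t MB       = 1024ull * 1024ull;
        const size_t n_ctx    = 512;
        const size_t scratch0 = (n_ctx / 16ull + 100ull) * MB;   // MEM_REQ_SCRATCH0, MODEL_7B
        const size_t scratch1 = 160ull * MB;                     // MEM_REQ_SCRATCH1, MODEL_7B
        const size_t eval     = 10ull * MB;                      // MEM_REQ_EVAL,     MODEL_7B
        const size_t kv_state = 2ull * 4096 * n_ctx * 32 * 2;    // hparams.kv_size(), 7B shapes
        // prints: compute: 302 MB (+ 256 MB per state)
        printf("compute: %zu MB (+ %zu MB per state)\n",
               (scratch0 + scratch1 + eval) / MB, kv_state / MB);
        return 0;
    }

With the previous tables the same configuration reserved roughly 1.5 GB of scratch and eval buffers plus a flat 1026 MB per state, which illustrates the reduction this commit is after.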
Georgi Gerganov