mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	Store layers in VRAM
This commit is contained in:
		| @@ -271,6 +271,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { | |||||||
|             params.use_color = true; |             params.use_color = true; | ||||||
|         } else if (arg == "--mlock") { |         } else if (arg == "--mlock") { | ||||||
|             params.use_mlock = true; |             params.use_mlock = true; | ||||||
|  |         } else if (arg == "--gpu_layers") { | ||||||
|  |             if (++i >= argc) { | ||||||
|  |                 invalid_param = true; | ||||||
|  |                 break; | ||||||
|  |             } | ||||||
|  |             params.gpu_layers = std::stoi(argv[i]); | ||||||
|         } else if (arg == "--no-mmap") { |         } else if (arg == "--no-mmap") { | ||||||
|             params.use_mmap = false; |             params.use_mmap = false; | ||||||
|         } else if (arg == "--mtest") { |         } else if (arg == "--mtest") { | ||||||
| @@ -406,6 +412,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { | |||||||
|     if (llama_mmap_supported()) { |     if (llama_mmap_supported()) { | ||||||
|         fprintf(stderr, "  --no-mmap             do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); |         fprintf(stderr, "  --no-mmap             do not memory-map model (slower load but may reduce pageouts if not using mlock)\n"); | ||||||
|     } |     } | ||||||
|  |     fprintf(stderr, "  --gpu_layers          number of layers to store in VRAM"); | ||||||
|     fprintf(stderr, "  --mtest               compute maximum memory usage\n"); |     fprintf(stderr, "  --mtest               compute maximum memory usage\n"); | ||||||
|     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n"); |     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n"); | ||||||
|     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n"); |     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n"); | ||||||
| @@ -454,6 +461,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) { | |||||||
|     lparams.f16_kv     = params.memory_f16; |     lparams.f16_kv     = params.memory_f16; | ||||||
|     lparams.use_mmap   = params.use_mmap; |     lparams.use_mmap   = params.use_mmap; | ||||||
|     lparams.use_mlock  = params.use_mlock; |     lparams.use_mlock  = params.use_mlock; | ||||||
|  |     lparams.gpu_layers = params.gpu_layers; | ||||||
|     lparams.logits_all = params.perplexity; |     lparams.logits_all = params.perplexity; | ||||||
|     lparams.embedding  = params.embedding; |     lparams.embedding  = params.embedding; | ||||||
|  |  | ||||||
|   | |||||||
| @@ -68,6 +68,7 @@ struct gpt_params { | |||||||
|     bool perplexity        = false; // compute perplexity over the prompt |     bool perplexity        = false; // compute perplexity over the prompt | ||||||
|     bool use_mmap          = true;  // use mmap for faster loads |     bool use_mmap          = true;  // use mmap for faster loads | ||||||
|     bool use_mlock         = false; // use mlock to keep model in memory |     bool use_mlock         = false; // use mlock to keep model in memory | ||||||
|  |     int gpu_layers         = 0;     // number of layers to store in VRAM | ||||||
|     bool mem_test          = false; // compute maximum memory usage |     bool mem_test          = false; // compute maximum memory usage | ||||||
|     bool verbose_prompt    = false; // print prompt tokens before generation |     bool verbose_prompt    = false; // print prompt tokens before generation | ||||||
| }; | }; | ||||||
|   | |||||||
							
								
								
									
										41
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							
							
						
						
									
										41
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							| @@ -349,7 +349,7 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { | |||||||
| } | } | ||||||
|  |  | ||||||
| // buffer pool for cuda | // buffer pool for cuda | ||||||
| #define MAX_CUDA_BUFFERS 16 | #define MAX_CUDA_BUFFERS 256 | ||||||
|  |  | ||||||
| struct scoped_spin_lock { | struct scoped_spin_lock { | ||||||
|     std::atomic_flag& lock; |     std::atomic_flag& lock; | ||||||
| @@ -678,9 +678,15 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor | |||||||
|             float * c_D = d_D + i * d_ne; |             float * c_D = d_D + i * d_ne; | ||||||
|             char  * c_Q = d_Q + i * q_sz; |             char  * c_Q = d_Q + i * q_sz; | ||||||
|  |  | ||||||
|             if (ne11 == 1) { |             // copy src0 to device if necessary | ||||||
|                 // copy src0 to device |             if (src0->backend == GGML_BACKEND_CPU) { | ||||||
|                 CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2)); |                 CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2)); | ||||||
|  |             } else if (src0->backend == GGML_BACKEND_CUDA) { | ||||||
|  |                 c_Q = ((char *) src0->data) + i * q_sz; | ||||||
|  |             } else { | ||||||
|  |                 GGML_ASSERT(false); | ||||||
|  |             } | ||||||
|  |             if (ne11 == 1) { | ||||||
|                 CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); |                 CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); | ||||||
|  |  | ||||||
|                 // copy src1 to device |                 // copy src1 to device | ||||||
| @@ -696,8 +702,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor | |||||||
|             } else { |             } else { | ||||||
|                 float * c_X = d_X + i * x_ne; |                 float * c_X = d_X + i * x_ne; | ||||||
|  |  | ||||||
|                 // copy src0 and convert to fp32 on device |                 // convert src0 to fp32 on device | ||||||
|                 CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2)); |  | ||||||
|                 to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2); |                 to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2); | ||||||
|                 CUDA_CHECK(cudaGetLastError()); |                 CUDA_CHECK(cudaGetLastError()); | ||||||
|                 CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); |                 CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); | ||||||
| @@ -742,8 +747,8 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te | |||||||
|     // TODO: find the optimal values for these |     // TODO: find the optimal values for these | ||||||
|     if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && |     if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && | ||||||
|         src1->type == GGML_TYPE_F32 && |         src1->type == GGML_TYPE_F32 && | ||||||
|         dst->type == GGML_TYPE_F32) { |         dst->type == GGML_TYPE_F32 && | ||||||
|  |         ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) { | ||||||
|         return true; |         return true; | ||||||
|     } |     } | ||||||
|  |  | ||||||
| @@ -795,3 +800,25 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct | |||||||
|         return 0; |         return 0; | ||||||
|     } |     } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | void ggml_cuda_transform_tensor(ggml_tensor * tensor) { | ||||||
|  |     const int64_t ne0 = tensor->ne[0]; | ||||||
|  |     const int64_t ne1 = tensor->ne[1]; | ||||||
|  |     const int64_t ne2 = tensor->ne[2]; | ||||||
|  |     const int64_t ne3 = tensor->ne[3]; | ||||||
|  |  | ||||||
|  |     const ggml_type type = tensor->type; | ||||||
|  |     const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type); | ||||||
|  |  | ||||||
|  |     size_t q_size; | ||||||
|  |     char * d_Q = (char *) ggml_cuda_pool_malloc(q_sz, &q_size); | ||||||
|  |  | ||||||
|  |     cudaStream_t cudaStream2 = g_cudaStreams2[0]; | ||||||
|  |  | ||||||
|  |     // copy tensor to device | ||||||
|  |     CUDA_CHECK(ggml_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2)); | ||||||
|  |     CUDA_CHECK(cudaDeviceSynchronize()); | ||||||
|  |  | ||||||
|  |     tensor->data = d_Q; | ||||||
|  |     tensor->backend = GGML_BACKEND_CUDA; | ||||||
|  | } | ||||||
|   | |||||||
| @@ -14,6 +14,8 @@ void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens | |||||||
| void * ggml_cuda_host_malloc(size_t size); | void * ggml_cuda_host_malloc(size_t size); | ||||||
| void   ggml_cuda_host_free(void * ptr); | void   ggml_cuda_host_free(void * ptr); | ||||||
|  |  | ||||||
|  | void ggml_cuda_transform_tensor(struct ggml_tensor * tensor); | ||||||
|  |  | ||||||
| #ifdef  __cplusplus | #ifdef  __cplusplus | ||||||
| } | } | ||||||
| #endif | #endif | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								ggml.c
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								ggml.c
									
									
									
									
									
								
							| @@ -4711,6 +4711,7 @@ struct ggml_tensor * ggml_new_tensor_impl( | |||||||
|  |  | ||||||
|     *result = (struct ggml_tensor) { |     *result = (struct ggml_tensor) { | ||||||
|         /*.type         =*/ type, |         /*.type         =*/ type, | ||||||
|  |         /*.backend      =*/ GGML_BACKEND_CPU, | ||||||
|         /*.n_dims       =*/ n_dims, |         /*.n_dims       =*/ n_dims, | ||||||
|         /*.ne           =*/ { 1, 1, 1, 1 }, |         /*.ne           =*/ { 1, 1, 1, 1 }, | ||||||
|         /*.nb           =*/ { 0, 0, 0, 0 }, |         /*.nb           =*/ { 0, 0, 0, 0 }, | ||||||
|   | |||||||
							
								
								
									
										8
									
								
								ggml.h
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								ggml.h
									
									
									
									
									
								
							| @@ -243,6 +243,11 @@ extern "C" { | |||||||
|         GGML_TYPE_COUNT, |         GGML_TYPE_COUNT, | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|  |     enum ggml_backend { | ||||||
|  |         GGML_BACKEND_CPU = 0, | ||||||
|  |         GGML_BACKEND_CUDA = 1, | ||||||
|  |     }; | ||||||
|  |  | ||||||
|     // model file types |     // model file types | ||||||
|     enum ggml_ftype { |     enum ggml_ftype { | ||||||
|         GGML_FTYPE_UNKNOWN     = -1, |         GGML_FTYPE_UNKNOWN     = -1, | ||||||
| @@ -323,6 +328,7 @@ extern "C" { | |||||||
|     // n-dimensional tensor |     // n-dimensional tensor | ||||||
|     struct ggml_tensor { |     struct ggml_tensor { | ||||||
|         enum ggml_type type; |         enum ggml_type type; | ||||||
|  |         enum ggml_backend backend; | ||||||
|  |  | ||||||
|         int     n_dims; |         int     n_dims; | ||||||
|         int64_t ne[GGML_MAX_DIMS]; // number of elements |         int64_t ne[GGML_MAX_DIMS]; // number of elements | ||||||
| @@ -353,7 +359,7 @@ extern "C" { | |||||||
|  |  | ||||||
|         char name[32]; |         char name[32]; | ||||||
|  |  | ||||||
|         char padding[8]; // TODO: remove and add padding to name? |         char padding[9]; // TODO: remove and add padding to name? | ||||||
|     }; |     }; | ||||||
|  |  | ||||||
|     // computation graph |     // computation graph | ||||||
|   | |||||||
							
								
								
									
										22
									
								
								llama.cpp
									
									
									
									
									
								
							
							
						
						
									
										22
									
								
								llama.cpp
									
									
									
									
									
								
							| @@ -9,6 +9,9 @@ | |||||||
| #include "llama.h" | #include "llama.h" | ||||||
|  |  | ||||||
| #include "ggml.h" | #include "ggml.h" | ||||||
|  | #ifdef GGML_USE_CUBLAS | ||||||
|  | #include "ggml-cuda.h" | ||||||
|  | #endif | ||||||
|  |  | ||||||
| #include <array> | #include <array> | ||||||
| #include <ctime> | #include <ctime> | ||||||
| @@ -815,6 +818,7 @@ struct llama_context_params llama_context_default_params() { | |||||||
|         /*.vocab_only                  =*/ false, |         /*.vocab_only                  =*/ false, | ||||||
|         /*.use_mmap                    =*/ true, |         /*.use_mmap                    =*/ true, | ||||||
|         /*.use_mlock                   =*/ false, |         /*.use_mlock                   =*/ false, | ||||||
|  |         /*.gpu_layers                  =*/ 0, | ||||||
|         /*.embedding                   =*/ false, |         /*.embedding                   =*/ false, | ||||||
|         /*.progress_callback           =*/ nullptr, |         /*.progress_callback           =*/ nullptr, | ||||||
|         /*.progress_callback_user_data =*/ nullptr, |         /*.progress_callback_user_data =*/ nullptr, | ||||||
| @@ -877,6 +881,7 @@ static void llama_model_load_internal( | |||||||
|         ggml_type memory_type, |         ggml_type memory_type, | ||||||
|         bool use_mmap, |         bool use_mmap, | ||||||
|         bool use_mlock, |         bool use_mlock, | ||||||
|  |         int gpu_layers, | ||||||
|         bool vocab_only, |         bool vocab_only, | ||||||
|         llama_progress_callback progress_callback, |         llama_progress_callback progress_callback, | ||||||
|         void * progress_callback_user_data) { |         void * progress_callback_user_data) { | ||||||
| @@ -1011,6 +1016,18 @@ static void llama_model_load_internal( | |||||||
|     ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL); |     ml->load_all_data(progress_callback, progress_callback_user_data, use_mlock ? &lctx.model.mlock_mmap : NULL); | ||||||
|  |  | ||||||
|     model.mapping = std::move(ml->mapping); |     model.mapping = std::move(ml->mapping); | ||||||
|  | #ifdef GGML_USE_CUBLAS | ||||||
|  |     for (int i = 0; i < std::min(gpu_layers, int(hparams.n_layer)); ++i) { | ||||||
|  |         auto & layer = model.layers[i]; | ||||||
|  |         ggml_cuda_transform_tensor(layer.wq); | ||||||
|  |         ggml_cuda_transform_tensor(layer.wk); | ||||||
|  |         ggml_cuda_transform_tensor(layer.wv); | ||||||
|  |         ggml_cuda_transform_tensor(layer.wo); | ||||||
|  |         ggml_cuda_transform_tensor(layer.w1); | ||||||
|  |         ggml_cuda_transform_tensor(layer.w2); | ||||||
|  |         ggml_cuda_transform_tensor(layer.w3); | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|     // loading time will be recalculate after the first eval, so |     // loading time will be recalculate after the first eval, so | ||||||
|     // we take page faults deferred by mmap() into consideration |     // we take page faults deferred by mmap() into consideration | ||||||
| @@ -1024,11 +1041,12 @@ static bool llama_model_load( | |||||||
|         ggml_type memory_type, |         ggml_type memory_type, | ||||||
|         bool use_mmap, |         bool use_mmap, | ||||||
|         bool use_mlock, |         bool use_mlock, | ||||||
|  |         int gpu_layers, | ||||||
|         bool vocab_only, |         bool vocab_only, | ||||||
|         llama_progress_callback progress_callback, |         llama_progress_callback progress_callback, | ||||||
|         void *progress_callback_user_data) { |         void *progress_callback_user_data) { | ||||||
|     try { |     try { | ||||||
|         llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock, |         llama_model_load_internal(fname, lctx, n_ctx, memory_type, use_mmap, use_mlock, gpu_layers, | ||||||
|                                   vocab_only, progress_callback, progress_callback_user_data); |                                   vocab_only, progress_callback, progress_callback_user_data); | ||||||
|         return true; |         return true; | ||||||
|     } catch (const std::string & err) { |     } catch (const std::string & err) { | ||||||
| @@ -2088,7 +2106,7 @@ struct llama_context * llama_init_from_file( | |||||||
|     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; |     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32; | ||||||
|  |  | ||||||
|     if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type, |     if (!llama_model_load(path_model, *ctx, params.n_ctx, memory_type, | ||||||
|                           params.use_mmap, params.use_mlock, params.vocab_only, |                           params.use_mmap, params.use_mlock, params.gpu_layers, params.vocab_only, | ||||||
|                           params.progress_callback, params.progress_callback_user_data)) { |                           params.progress_callback, params.progress_callback_user_data)) { | ||||||
|         fprintf(stderr, "%s: failed to load model\n", __func__); |         fprintf(stderr, "%s: failed to load model\n", __func__); | ||||||
|         llama_free(ctx); |         llama_free(ctx); | ||||||
|   | |||||||
							
								
								
									
										1
									
								
								llama.h
									
									
									
									
									
								
							
							
						
						
									
										1
									
								
								llama.h
									
									
									
									
									
								
							| @@ -63,6 +63,7 @@ extern "C" { | |||||||
|         bool vocab_only; // only load the vocabulary, no weights |         bool vocab_only; // only load the vocabulary, no weights | ||||||
|         bool use_mmap;   // use mmap if possible |         bool use_mmap;   // use mmap if possible | ||||||
|         bool use_mlock;  // force system to keep model in RAM |         bool use_mlock;  // force system to keep model in RAM | ||||||
|  |         int gpu_layers;  // number of layers to store in VRAM | ||||||
|         bool embedding;  // embedding mode only |         bool embedding;  // embedding mode only | ||||||
|  |  | ||||||
|         // called with a progress value between 0 and 1, pass NULL to disable |         // called with a progress value between 0 and 1, pass NULL to disable | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 JohannesGaessler
					JohannesGaessler