	CUDA: mmq CLI option, fixed mmq build issues (#2453)
CMakeLists.txt

@@ -68,7 +68,7 @@ option(LLAMA_ACCELERATE                      "llama: enable Accelerate framework
 option(LLAMA_BLAS                            "llama: use BLAS"                                  OFF)
 set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
 option(LLAMA_CUBLAS                          "llama: use CUDA"                                  OFF)
-option(LLAMA_CUDA_CUBLAS                     "llama: use cuBLAS for prompt processing"          OFF)
+#option(LLAMA_CUDA_CUBLAS                     "llama: use cuBLAS for prompt processing"          OFF)
 set(LLAMA_CUDA_MMQ_Y       "64" CACHE STRING "llama: y tile size for mmq CUDA kernels")
 option(LLAMA_CUDA_FORCE_DMMV                 "llama: use dmmv instead of mmvq CUDA kernels"     OFF)
 set(LLAMA_CUDA_DMMV_X      "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
@@ -253,9 +253,9 @@ if (LLAMA_CUBLAS)
         set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)

         add_compile_definitions(GGML_USE_CUBLAS)
-        if (LLAMA_CUDA_CUBLAS)
-            add_compile_definitions(GGML_CUDA_CUBLAS)
-        endif()
+#        if (LLAMA_CUDA_CUBLAS)
+#            add_compile_definitions(GGML_CUDA_CUBLAS)
+#        endif()
         add_compile_definitions(GGML_CUDA_MMQ_Y=${LLAMA_CUDA_MMQ_Y})
         if (LLAMA_CUDA_FORCE_DMMV)
             add_compile_definitions(GGML_CUDA_FORCE_DMMV)
@@ -277,10 +277,14 @@ if (LLAMA_CUBLAS)
         endif()

     if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
+        # 52 == lowest CUDA 12 standard
+        # 60 == f16 CUDA intrinsics
+        # 61 == integer CUDA intrinsics
+        # 70 == (assumed) compute capability at which unrolling a loop in mul_mat_q kernels is faster
         if (LLAMA_CUDA_DMMV_F16)
-            set(CMAKE_CUDA_ARCHITECTURES "60;61") # needed for f16 CUDA intrinsics
+            set(CMAKE_CUDA_ARCHITECTURES "60;61;70") # needed for f16 CUDA intrinsics
         else()
-            set(CMAKE_CUDA_ARCHITECTURES "52;61") # lowest CUDA 12 standard + lowest for integer intrinsics
+            set(CMAKE_CUDA_ARCHITECTURES "52;61;70") # lowest CUDA 12 standard + lowest for integer intrinsics
         endif()
     endif()
     message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
Makefile (6 lines changed)
@@ -236,9 +236,9 @@ ifdef LLAMA_CUDA_MMQ_Y
 else
 	NVCCFLAGS += -DGGML_CUDA_MMQ_Y=64
 endif # LLAMA_CUDA_MMQ_Y
-ifdef LLAMA_CUDA_CUBLAS
-	NVCCFLAGS += -DGGML_CUDA_CUBLAS
-endif # LLAMA_CUDA_CUBLAS
+#ifdef LLAMA_CUDA_CUBLAS
+#	NVCCFLAGS += -DGGML_CUDA_CUBLAS
+#endif # LLAMA_CUDA_CUBLAS
 ifdef LLAMA_CUDA_CCBIN
 	NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
 endif
README.md

@@ -400,9 +400,11 @@ Building the program with BLAS support may lead to some performance improvements

   The environment variable [`CUDA_VISIBLE_DEVICES`](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars) can be used to specify which GPU(s) will be used. The following compilation options are also available to tweak performance:

+<!---
+  | LLAMA_CUDA_CUBLAS       | Boolean                |   false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). |
+--->
   | Option                  | Legal values           | Default | Description |
   |-------------------------|------------------------|---------|-------------|
-  | LLAMA_CUDA_CUBLAS       | Boolean                |   false | Use cuBLAS instead of custom CUDA kernels for prompt processing. Faster for all quantization formats except for q4_0 and q8_0, especially for k-quants. Increases VRAM usage (700 MiB for 7b, 970 MiB for 13b, 1430 MiB for 33b). |
   | LLAMA_CUDA_MMQ_Y        | Positive integer >= 32 |      64 | Tile size in y direction when using the custom CUDA kernels for prompt processing. Higher values can be faster depending on the amount of shared memory available. Power of 2 heavily recommended. |
   | LLAMA_CUDA_FORCE_DMMV   | Boolean                |   false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. |
   | LLAMA_CUDA_DMMV_X       | Positive integer >= 32 |      32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
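Not part of the diff: the build options in the table above reach the CUDA code as GGML_CUDA_* preprocessor definitions (see the CMakeLists.txt and Makefile hunks earlier in this commit). Below is a rough, hypothetical sketch of how such definitions are typically consumed on the C++ side; the real definitions live in ggml-cuda.cu and may differ in detail.

    // Hypothetical sketch only: defaults mirror the table above, not the literal ggml-cuda.cu code.
    #ifndef GGML_CUDA_DMMV_X
    #define GGML_CUDA_DMMV_X 32   // LLAMA_CUDA_DMMV_X: x-direction values per dmmv iteration
    #endif
    #ifndef GGML_CUDA_MMQ_Y
    #define GGML_CUDA_MMQ_Y  64   // LLAMA_CUDA_MMQ_Y: y tile size for the mul_mat_q kernels
    #endif
    #ifdef GGML_CUDA_FORCE_DMMV   // LLAMA_CUDA_FORCE_DMMV: always use dequantize + mat-vec kernels
    static const bool use_mul_mat_vec_q_default = false; // illustrative name, not from the commit
    #else
    static const bool use_mul_mat_vec_q_default = true;
    #endif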
common.cpp

@@ -377,6 +377,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             }
 #else
             fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
+#endif // GGML_USE_CUBLAS
+        } else if (arg == "--mul-mat-q" || arg == "-mmq") {
+#ifdef GGML_USE_CUBLAS
+            params.mul_mat_q = true;
+#else
+            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to use mul_mat_q kernels.\n");
 #endif // GGML_USE_CUBLAS
         } else if (arg == "--low-vram" || arg == "-lv") {
 #ifdef GGML_USE_CUBLAS
@@ -585,6 +591,9 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stdout, "                        how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
     fprintf(stdout, "  -mg i, --main-gpu i   the GPU to use for scratch and small tensors\n" );
     fprintf(stdout, "  -lv, --low-vram       don't allocate VRAM scratch buffer\n" );
+    fprintf(stdout, "  -mmq, --mul-mat-q     use experimental mul_mat_q CUDA kernels instead of cuBLAS. TEMP!!!\n" );
+    fprintf(stdout, "                        Reduces VRAM usage by 700/970/1430 MiB for 7b/13b/33b but prompt processing speed\n" );
+    fprintf(stdout, "                        is still suboptimal, especially q2_K, q3_K, q5_K, and q6_K.\n" );
 #endif
     fprintf(stdout, "  --mtest               compute maximum memory usage\n");
     fprintf(stdout, "  --export              export the computation graph to 'llama.ggml'\n");
@@ -637,6 +646,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_params
     lparams.main_gpu        = params.main_gpu;
     lparams.tensor_split    = params.tensor_split;
     lparams.low_vram        = params.low_vram;
+    lparams.mul_mat_q       = params.mul_mat_q;
     lparams.seed            = params.seed;
     lparams.f16_kv          = params.memory_f16;
     lparams.use_mmap        = params.use_mmap;
common.h

@@ -74,6 +74,7 @@ struct gpt_params {
     size_t hellaswag_tasks = 400;   // number of tasks to use when computing the HellaSwag score

     bool low_vram          = false; // if true, reduce VRAM usage at the cost of performance
+    bool mul_mat_q         = false; // if true, use experimental mul_mat_q kernels
     bool memory_f16        = true;  // use f16 instead of f32 for memory kv
     bool random_prompt     = false; // do not randomize prompt if none provided
     bool use_color         = false; // use color to distinguish generations and inputs
examples/server/server.cpp

@@ -631,6 +631,9 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
     fprintf(stdout, "                        how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
     fprintf(stdout, "  -mg i, --main-gpu i   the GPU to use for scratch and small tensors\n");
     fprintf(stdout, "  -lv, --low-vram don't allocate VRAM scratch buffer\n");
+    fprintf(stdout, "  -mmq, --mul-mat-q     use experimental mul_mat_q CUDA kernels instead of cuBLAS. TEMP!!!\n" );
+    fprintf(stdout, "                        Reduces VRAM usage by 700/970/1430 MiB for 7b/13b/33b but prompt processing speed\n" );
+    fprintf(stdout, "                        is still suboptimal, especially q2_K, q3_K, q5_K, and q6_K.\n" );
 #endif
     fprintf(stdout, "  -m FNAME, --model FNAME\n");
     fprintf(stdout, "                        model path (default: %s)\n", params.model.c_str());
@@ -827,7 +830,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                 }
             }
 #else
-            LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.", {});
+            LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n", {});
 #endif // GGML_USE_CUBLAS
         }
         else if (arg == "--low-vram" || arg == "-lv")
@@ -835,7 +838,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
 #ifdef GGML_USE_CUBLAS
             params.low_vram = true;
 #else
-            fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
+            LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n", {});
+#endif // GGML_USE_CUBLAS
+        }
+        else if (arg == "--mul-mat-q" || arg == "-mmq")
+        {
+#ifdef GGML_USE_CUBLAS
+            params.mul_mat_q = true;
+#else
+            LOG_WARNING("warning: llama.cpp was compiled without cuBLAS. It is not possible to use mul_mat_q kernels.\n", {});
 #endif // GGML_USE_CUBLAS
         }
         else if (arg == "--main-gpu" || arg == "-mg")
ggml-cuda.cu (24 lines changed)
@@ -3898,10 +3898,9 @@ static size_t g_scratch_offset = 0;

 static int g_device_count = -1;
 static int g_main_device = 0;
-#ifndef GGML_CUDA_FORCE_DMMV
 static int g_compute_capabilities[GGML_CUDA_MAX_DEVICES];
-#endif
 static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
+static bool g_mul_mat_q = false;

 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};

@@ -3923,9 +3922,7 @@ void ggml_init_cublas() {
             g_tensor_split[id] = total_vram;
             total_vram += prop.totalGlobalMem;

-#ifndef GGML_CUDA_FORCE_DMMV
             g_compute_capabilities[id] = 100*prop.major + 10*prop.minor;
-#endif
         }
         for (int id = 0; id < g_device_count; ++id) {
             g_tensor_split[id] /= total_vram;
@@ -4278,6 +4275,7 @@ inline void ggml_cuda_op_mul_mat_vec(

 #ifdef GGML_CUDA_FORCE_DMMV
     const bool use_mul_mat_vec_q = false;
+    (void) g_compute_capabilities[0];
 #else
     int id;
     CUDA_CHECK(cudaGetDevice(&id));
@@ -5021,12 +5019,14 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
         if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
             ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_vec, false, false);
         } else {
-#ifdef GGML_CUDA_CUBLAS
-            const bool use_mul_mat_q = false;
-#else
-            const bool use_mul_mat_q = ggml_is_quantized(src0->type);
-#endif // GGML_CUDA_CUBLAS
-            if (use_mul_mat_q) {
+            int min_compute_capability = INT_MAX;
+            for (int id = 0; id < g_device_count; ++id) {
+                if (min_compute_capability > g_compute_capabilities[id]) {
+                    min_compute_capability = g_compute_capabilities[id];
+                }
+            }
+
+            if (g_mul_mat_q && ggml_is_quantized(src0->type) && min_compute_capability >= MIN_CC_DP4A) {
                 ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_q, false, false);
             } else {
                 ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true, false);
@@ -5320,6 +5320,10 @@ void ggml_cuda_set_main_device(int main_device) {
     }
 }

+void ggml_cuda_set_mul_mat_q(bool mul_mat_q) {
+    g_mul_mat_q = mul_mat_q;
+}
+
 void ggml_cuda_set_scratch_size(size_t scratch_size) {
     g_scratch_size = scratch_size;
 }
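In words, the compile-time GGML_CUDA_CUBLAS switch is replaced by a runtime decision: the mul_mat_q kernels are used only when the user enabled them, the weight tensor is quantized, and every visible device supports the integer dot-product path. The standalone wrapper below is an illustrative restatement, not code from the commit; MIN_CC_DP4A is the threshold ggml-cuda.cu uses for the __dp4a intrinsic, which corresponds to compute capability 6.1.

    #include <climits>

    static const int MIN_CC_DP4A = 610; // compute capability 6.1, needed for __dp4a

    // Illustrative helper mirroring the new check in ggml_cuda_mul_mat() above.
    static bool should_use_mul_mat_q(bool mul_mat_q_enabled, bool src0_is_quantized,
                                     const int * compute_capabilities, int device_count) {
        int min_compute_capability = INT_MAX;
        for (int id = 0; id < device_count; ++id) {
            if (compute_capabilities[id] < min_compute_capability) {
                min_compute_capability = compute_capabilities[id];
            }
        }
        // All three conditions must hold; otherwise the cuBLAS path is taken.
        return mul_mat_q_enabled && src0_is_quantized && min_compute_capability >= MIN_CC_DP4A;
    }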
ggml-cuda.h

@@ -27,6 +27,7 @@ void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
 void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
 void   ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
 void   ggml_cuda_set_main_device(int main_device);
+void   ggml_cuda_set_mul_mat_q(bool mul_mat_q);
 void   ggml_cuda_set_scratch_size(size_t scratch_size);
 void   ggml_cuda_free_scratch(void);
 bool   ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
llama.cpp (10 lines changed)
@@ -901,6 +901,7 @@ struct llama_context_params llama_context_default_params() {
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
         /*.low_vram                    =*/ false,
+        /*.mul_mat_q                   =*/ false,
         /*.f16_kv                      =*/ true,
         /*.logits_all                  =*/ false,
         /*.vocab_only                  =*/ false,
@@ -1028,6 +1029,7 @@ static void llama_model_load_internal(
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
+        const bool mul_mat_q,
         float rope_freq_base,
         float rope_freq_scale,
         bool low_vram,
@@ -1156,9 +1158,11 @@ static void llama_model_load_internal(
     }

     (void) main_gpu;
+    (void) mul_mat_q;
 #if defined(GGML_USE_CUBLAS)
     fprintf(stderr, "%s: using CUDA for GPU acceleration\n", __func__);
     ggml_cuda_set_main_device(main_gpu);
+    ggml_cuda_set_mul_mat_q(mul_mat_q);
 #define LLAMA_BACKEND_OFFLOAD       GGML_BACKEND_GPU
 #define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
 #elif defined(GGML_USE_CLBLAST)
@@ -1367,6 +1371,7 @@ static bool llama_model_load(
         int n_gpu_layers,
         int main_gpu,
         const float * tensor_split,
+        const bool mul_mat_q,
         float rope_freq_base,
         float rope_freq_scale,
         bool low_vram,
@@ -1377,7 +1382,8 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers, main_gpu, tensor_split, rope_freq_base, rope_freq_scale, low_vram, memory_type,
+        llama_model_load_internal(fname, model, vocab, n_ctx, n_batch, n_gqa, rms_norm_eps, n_gpu_layers,
+                                  main_gpu, tensor_split, mul_mat_q, rope_freq_base, rope_freq_scale, low_vram, memory_type,
                                   use_mmap, use_mlock, vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::exception & err) {
@@ -3192,7 +3198,7 @@ struct llama_model * llama_load_model_from_file(
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;

     if (!llama_model_load(path_model, *model, model->vocab, params.n_ctx, params.n_batch, params.n_gqa, params.rms_norm_eps, params.n_gpu_layers,
-                params.main_gpu, params.tensor_split, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
+                params.main_gpu, params.tensor_split, params.mul_mat_q, params.rope_freq_base, params.rope_freq_scale,params.low_vram,
                 memory_type, params.use_mmap, params.use_mlock, params.vocab_only, params.progress_callback,
                 params.progress_callback_user_data)) {
         delete model;
llama.h (1 line changed)
							| @@ -108,6 +108,7 @@ extern "C" { | |||||||
|  |  | ||||||
|         // Keep the booleans together to avoid misalignment during copy-by-value. |         // Keep the booleans together to avoid misalignment during copy-by-value. | ||||||
|         bool low_vram;   // if true, reduce VRAM usage at the cost of performance |         bool low_vram;   // if true, reduce VRAM usage at the cost of performance | ||||||
|  |         bool mul_mat_q;  // if true, use experimental mul_mat_q kernels | ||||||
|         bool f16_kv;     // use fp16 for KV cache |         bool f16_kv;     // use fp16 for KV cache | ||||||
|         bool logits_all; // the llama_eval() call computes all logits, not just the last one |         bool logits_all; // the llama_eval() call computes all logits, not just the last one | ||||||
|         bool vocab_only; // only load the vocabulary, no weights |         bool vocab_only; // only load the vocabulary, no weights | ||||||
|   | |||||||
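Taken together, the new llama_context_params field gives API users the same switch that -mmq/--mul-mat-q exposes on the command line. A minimal, hypothetical usage sketch follows; it assumes the pre-GGUF llama.h API of this period (llama_context_default_params, llama_load_model_from_file, llama_new_context_with_model), and the model path is a placeholder.

    #include "llama.h"

    int main(void) {
        llama_context_params params = llama_context_default_params();
        params.n_gpu_layers = 32;   // offload layers to the GPU as usual
        params.mul_mat_q    = true; // opt in to the experimental quantized mul_mat_q kernels

        llama_model * model = llama_load_model_from_file("models/7B/ggml-model-q4_0.bin", params);
        if (model == NULL) {
            return 1;
        }
        llama_context * ctx = llama_new_context_with_model(model, params);

        // ... evaluate tokens and sample as usual ...

        llama_free(ctx);
        llama_free_model(model);
        return 0;
    }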
Author: Johannes Gäßler