Mirror of https://github.com/ggml-org/llama.cpp.git (synced 2025-10-31 08:51:55 +00:00)
	cuda : supports running on CPU for GGML_USE_CUBLAS=ON build (#3946)
* prototyping the idea that supports running on CPU for a GGML_USE_CUBLAS=on build

* doc: add comments to ggml_cublas_loaded()

* fix defined(...)
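The shape of the change is small: ggml_init_cublas() no longer aborts when cudaGetDeviceCount() fails, it just records the failure, and callers ask ggml_cublas_loaded() before taking any CUDA path (ggml_cuda_can_mul_mat() and ggml_cuda_compute_forward() now return false when cuBLAS did not load). A minimal caller-side sketch, not part of the commit, assuming a GGML_USE_CUBLAS=ON build and a hypothetical standalone main():

    // Hypothetical probe of the new API; only ggml_init_cublas() and
    // ggml_cublas_loaded() from ggml-cuda.h come from this commit.
    #include <cstdio>
    #include "ggml-cuda.h"

    int main() {
        ggml_init_cublas();   // safe even on a machine without a CUDA device

        if (ggml_cublas_loaded()) {
            std::printf("CUDA available: GPU offload can be used\n");
        } else {
            std::printf("no usable CUDA device: running on CPU only\n");
            // ggml_cuda_can_mul_mat()/ggml_cuda_compute_forward() return false,
            // so graph evaluation stays on the CPU implementations.
        }
        return 0;
    }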
ggml-cuda.cu (17 lines changed)
@@ -5790,6 +5790,11 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
     CUDA_CHECK(cudaFree(ptr));
 }
 
+static bool g_cublas_loaded = false;
+
+bool ggml_cublas_loaded(void) {
+    return g_cublas_loaded;
+}
+
 void ggml_init_cublas() {
     static bool initialized = false;

@@ -5803,7 +5808,12 @@ void ggml_init_cublas() {
         CUDA_CHECK(cudaDeviceSynchronize());
 #endif
 
-        CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
+        if (cudaGetDeviceCount(&g_device_count) != cudaSuccess) {
+            initialized = true;
+            g_cublas_loaded = false;
+            return;
+        }
+
         GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
         int64_t total_vram = 0;
 #if defined(GGML_CUDA_FORCE_MMQ)

@@ -5851,6 +5861,7 @@ void ggml_init_cublas() {
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
 
         initialized = true;
+        g_cublas_loaded = true;
     }
 }
 

@@ -7158,6 +7169,8 @@ static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    if (!g_cublas_loaded) return false;
+
     const int64_t ne10 = src1->ne[0];
 
     const int64_t ne0 = dst->ne[0];

@@ -7843,6 +7856,8 @@ void ggml_cuda_free_scratch() {
 }
 
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+    if (!g_cublas_loaded) return false;
+
     ggml_cuda_func_t func;
     const bool any_on_device = tensor->backend == GGML_BACKEND_GPU
         || (tensor->src[0] != nullptr && (tensor->src[0]->backend == GGML_BACKEND_GPU || tensor->src[0]->backend == GGML_BACKEND_GPU_SPLIT))

ggml-cuda.h

@@ -17,7 +17,12 @@ extern "C" {
 
 #define GGML_CUDA_MAX_DEVICES       16
 
+// Always success. To check if CUDA is actually loaded, use `ggml_cublas_loaded`.
 GGML_API void   ggml_init_cublas(void);
+
+// Returns `true` if there are available CUDA devices and cublas loads successfully; otherwise, it returns `false`.
+GGML_API bool   ggml_cublas_loaded(void);
+
 GGML_API void * ggml_cuda_host_malloc(size_t size);
 GGML_API void   ggml_cuda_host_free(void * ptr);
 
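On the llama.cpp side, the llama_host_malloc/llama_host_free macros become inline functions so the CPU fallback can be chosen at run time rather than at compile time; the hunks below are the actual change. As a rough standalone sketch of the same idea, using only the declarations above plus the existing ggml_cuda_host_malloc/ggml_cuda_host_free (the names host_malloc/host_free are illustrative, not from the commit):

    #include <cstdlib>
    #include "ggml-cuda.h"

    // Pinned CUDA host memory when cuBLAS actually loaded, plain heap memory otherwise.
    static void * host_malloc(size_t n) {
        return ggml_cublas_loaded() ? ggml_cuda_host_malloc(n) : std::malloc(n);
    }

    // Free with whichever allocator the matching allocation used.
    static void host_free(void * ptr) {
        if (ggml_cublas_loaded()) {
            ggml_cuda_host_free(ptr);
        } else {
            std::free(ptr);
        }
    }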
							
								
								
									
llama.cpp (181 lines changed)
@@ -596,19 +596,37 @@ static void ggml_graph_compute_helper(std::vector<uint8_t> & buf, ggml_cgraph *
 // llama helpers
 //
 
+inline void * llama_host_malloc(size_t n) {
 #ifdef GGML_USE_CUBLAS
-#   define llama_host_malloc(n)  ggml_cuda_host_malloc(n)
-#   define llama_host_free(data) ggml_cuda_host_free(data)
+    if (ggml_cublas_loaded()) {
+        return ggml_cuda_host_malloc(n);
+    } else {
+        return malloc(n);
+    }
 #elif GGML_USE_METAL
-#   define llama_host_malloc(n)  ggml_metal_host_malloc(n)
-#   define llama_host_free(data) ggml_metal_host_free(data)
+    return ggml_metal_host_malloc(n);
 #elif GGML_USE_CPU_HBM
-#   define llama_host_malloc(n)  hbw_malloc(n)
-#   define llama_host_free(data) if (data != NULL) hbw_free(data)
+    return hbw_malloc(n);
 #else
-#   define llama_host_malloc(n)  malloc(n)
-#   define llama_host_free(data) free(data)
+    return malloc(n);
 #endif
+}
+
+inline void llama_host_free(void * ptr) {
+#ifdef GGML_USE_CUBLAS
+    if (ggml_cublas_loaded()) {
+        return ggml_cuda_host_free(ptr);
+    } else {
+        return free(ptr);
+    }
+#elif GGML_USE_METAL
+    return ggml_metal_host_free(ptr);
+#elif GGML_USE_CPU_HBM
+    return hbw_free(ptr);
+#else
+    return free(ptr);
+#endif
+}
 
 #if defined(_WIN32)
 static std::string llama_format_win_err(DWORD err) {
@@ -1200,9 +1218,11 @@ struct llama_kv_cache {
         }
 
 #ifdef GGML_USE_CUBLAS
-        ggml_cuda_free_data(k);
-        ggml_cuda_free_data(v);
-#endif // GGML_USE_CUBLAS
+        if (ggml_cublas_loaded()) {
+            ggml_cuda_free_data(k);
+            ggml_cuda_free_data(v);
+        }
+#endif
     }
 };
 

@@ -1302,11 +1322,15 @@ struct llama_model {
         }
 
 #ifdef GGML_USE_CUBLAS
-        for (size_t i = 0; i < tensors_by_name.size(); ++i) {
-            ggml_cuda_free_data(tensors_by_name[i].second);
+        if (ggml_cublas_loaded()) {
+            for (size_t i = 0; i < tensors_by_name.size(); ++i) {
+                ggml_cuda_free_data(tensors_by_name[i].second);
+            }
+            ggml_cuda_free_scratch();
         }
-        ggml_cuda_free_scratch();
-#elif defined(GGML_USE_CLBLAST)
+#endif
+
+#if defined(GGML_USE_CLBLAST)
         for (size_t i = 0; i < tensors_by_name.size(); ++i) {
             ggml_cl_free_data(tensors_by_name[i].second);
         }
@@ -1418,23 +1442,26 @@ static bool llama_kv_cache_init(
     ggml_set_name(cache.v, "cache_v");
 
     (void) n_gpu_layers;
-#ifdef GGML_USE_CUBLAS
-    size_t vram_kv_cache = 0;
 
-    if (n_gpu_layers > (int)n_layer + 1) {
-        ggml_cuda_assign_buffers_no_scratch(cache.v);
-        LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
-        vram_kv_cache += ggml_nbytes(cache.v);
+#ifdef GGML_USE_CUBLAS
+    if (ggml_cublas_loaded()) {
+        size_t vram_kv_cache = 0;
+
+        if (n_gpu_layers > (int)n_layer + 1) {
+            ggml_cuda_assign_buffers_no_scratch(cache.v);
+            LLAMA_LOG_INFO("%s: offloading v cache to GPU\n", __func__);
+            vram_kv_cache += ggml_nbytes(cache.v);
+        }
+        if (n_gpu_layers > (int)n_layer + 2) {
+            ggml_cuda_assign_buffers_no_scratch(cache.k);
+            LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
+            vram_kv_cache += ggml_nbytes(cache.k);
+        }
+        if (vram_kv_cache > 0) {
+            LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
+        }
     }
-    if (n_gpu_layers > (int)n_layer + 2) {
-        ggml_cuda_assign_buffers_no_scratch(cache.k);
-        LLAMA_LOG_INFO("%s: offloading k cache to GPU\n", __func__);
-        vram_kv_cache += ggml_nbytes(cache.k);
-    }
-    if (vram_kv_cache > 0) {
-        LLAMA_LOG_INFO("%s: VRAM kv self = %.2f MB\n", __func__, vram_kv_cache / 1024.0 / 1024.0);
-    }
-#endif // GGML_USE_CUBLAS
+#endif
 
     return true;
 }
@@ -2521,18 +2548,22 @@ static void llm_load_tensors(
     }
 
     (void) main_gpu;
+
+    enum ggml_backend_type llama_backend_offload = GGML_BACKEND_CPU;
+    enum ggml_backend_type llama_backend_offload_split = GGML_BACKEND_CPU;
+
 #ifdef GGML_USE_CUBLAS
-    LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__);
-    ggml_cuda_set_main_device(main_gpu);
-#define LLAMA_BACKEND_OFFLOAD       GGML_BACKEND_GPU
-#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
+    if (ggml_cublas_loaded()) {
+        LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__);
+        ggml_cuda_set_main_device(main_gpu);
+
+        llama_backend_offload = GGML_BACKEND_GPU;
+        llama_backend_offload_split = GGML_BACKEND_GPU_SPLIT;
+    }
 #elif defined(GGML_USE_CLBLAST)
-    LLAMA_LOG_INFO("%s: using OpenCL for GPU acceleration\n", __func__);
-#define LLAMA_BACKEND_OFFLOAD       GGML_BACKEND_GPU
-#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU
-#else
-#define LLAMA_BACKEND_OFFLOAD       GGML_BACKEND_CPU
-#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_CPU
+        LLAMA_LOG_INFO("%s: using OpenCL for GPU acceleration\n", __func__);
+        llama_backend_offload = GGML_BACKEND_GPU;
+        llama_backend_offload_split = GGML_BACKEND_GPU;
 #endif
 
     // prepare memory for the weights
@@ -2559,12 +2590,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;

@@ -2588,8 +2619,8 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                         auto & layer = model.layers[i];
 

@@ -2625,12 +2656,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;

@@ -2654,8 +2685,8 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                         auto & layer = model.layers[i];
 

@@ -2695,12 +2726,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;

@@ -2726,8 +2757,8 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                         auto & layer = model.layers[i];
 
@@ -2772,12 +2803,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;

@@ -2803,8 +2834,8 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                         auto & layer = model.layers[i];
 

@@ -2849,12 +2880,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;

@@ -2877,8 +2908,8 @@ static void llm_load_tensors(
                     const int i_gpu_start = n_layer - n_gpu_layers;
                     model.layers.resize(n_layer);
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT;
+                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload;
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split;
                         auto & layer = model.layers[i];
                         layer.attn_norm     = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM,   "weight", i), {n_embd}, backend);
                         layer.attn_norm_b   = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM,   "bias",   i), {n_embd}, backend);

@@ -2915,12 +2946,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;

@@ -2946,8 +2977,8 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                        const ggml_backend_type backend       = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                         auto & layer = model.layers[i];
 

@@ -2993,12 +3024,12 @@ static void llm_load_tensors(
                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
                             // on Windows however this is detrimental unless everything is on the GPU
 #ifndef _WIN32
-                            backend_norm = LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = llama_backend_offload;
 #else
-                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+                            backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
 #endif // _WIN32
 
-                            backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
+                            backend_output = llama_backend_offload_split;
                         } else {
                             backend_norm   = GGML_BACKEND_CPU;
                             backend_output = GGML_BACKEND_CPU;

@@ -3022,8 +3053,8 @@ static void llm_load_tensors(
                     model.layers.resize(n_layer);
 
                     for (uint32_t i = 0; i < n_layer; ++i) {
-                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
-                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
+                        const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+                        const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
 
                         auto & layer = model.layers[i];
 
Author: Meng Zhang