mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	Add ReLU and SQR CUDA ops to (partially) fix Persimmon offloading (#4041)
* Add ReLU and SQR CUDA ops to fix Persimmon offloading * Persimmon loader: More helpful error on CUDA/ROCM when offloading too many layers
This commit is contained in:
		
							
								
								
									
										72
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							
							
						
						
									
										72
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							| @@ -433,6 +433,8 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ | ||||
| #define CUDA_MUL_BLOCK_SIZE 256 | ||||
| #define CUDA_GELU_BLOCK_SIZE 256 | ||||
| #define CUDA_SILU_BLOCK_SIZE 256 | ||||
| #define CUDA_RELU_BLOCK_SIZE 256 | ||||
| #define CUDA_SQR_BLOCK_SIZE 256 | ||||
| #define CUDA_CPY_BLOCK_SIZE 32 | ||||
| #define CUDA_SCALE_BLOCK_SIZE 256 | ||||
| #define CUDA_CLAMP_BLOCK_SIZE 256 | ||||
| @@ -553,6 +555,24 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) { | ||||
|     dst[i] = x[i] / (1.0f + expf(-x[i])); | ||||
| } | ||||
|  | ||||
| static __global__ void relu_f32(const float * x, float * dst, const int k) { | ||||
|     const int i = blockDim.x*blockIdx.x + threadIdx.x; | ||||
|  | ||||
|     if (i >= k) { | ||||
|         return; | ||||
|     } | ||||
|     dst[i] = fmaxf(x[i], 0); | ||||
| } | ||||
|  | ||||
| static __global__ void sqr_f32(const float * x, float * dst, const int k) { | ||||
|     const int i = blockDim.x*blockIdx.x + threadIdx.x; | ||||
|  | ||||
|     if (i >= k) { | ||||
|         return; | ||||
|     } | ||||
|     dst[i] = x[i] * x[i]; | ||||
| } | ||||
|  | ||||
| static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { | ||||
| #pragma unroll | ||||
|     for (int mask = 16; mask > 0; mask >>= 1) { | ||||
| @@ -4759,6 +4779,16 @@ static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_ | ||||
|     silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k); | ||||
| } | ||||
|  | ||||
| static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { | ||||
|     const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE; | ||||
|     relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k); | ||||
| } | ||||
|  | ||||
| static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { | ||||
|     const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE; | ||||
|     sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k); | ||||
| } | ||||
|  | ||||
| static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { | ||||
|     GGML_ASSERT(ncols % WARP_SIZE == 0); | ||||
|     if (ncols < 1024) { | ||||
| @@ -6128,6 +6158,34 @@ inline void ggml_cuda_op_silu( | ||||
|     (void) src1_dd; | ||||
| } | ||||
|  | ||||
| inline void ggml_cuda_op_relu( | ||||
|     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, | ||||
|     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { | ||||
|  | ||||
|     GGML_ASSERT(src0->type == GGML_TYPE_F32); | ||||
|     GGML_ASSERT( dst->type == GGML_TYPE_F32); | ||||
|  | ||||
|     relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); | ||||
|  | ||||
|     (void) src1; | ||||
|     (void) dst; | ||||
|     (void) src1_dd; | ||||
| } | ||||
|  | ||||
| inline void ggml_cuda_op_sqr( | ||||
|     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, | ||||
|     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { | ||||
|  | ||||
|     GGML_ASSERT(src0->type == GGML_TYPE_F32); | ||||
|     GGML_ASSERT( dst->type == GGML_TYPE_F32); | ||||
|  | ||||
|     sqr_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream); | ||||
|  | ||||
|     (void) src1; | ||||
|     (void) dst; | ||||
|     (void) src1_dd; | ||||
| } | ||||
|  | ||||
| inline void ggml_cuda_op_norm( | ||||
|     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, | ||||
|     const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { | ||||
| @@ -7160,6 +7218,14 @@ static void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, g | ||||
|     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu); | ||||
| } | ||||
|  | ||||
| static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||
|     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu); | ||||
| } | ||||
|  | ||||
| static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||
|     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr); | ||||
| } | ||||
|  | ||||
| static void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||
|     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm); | ||||
| } | ||||
| @@ -7891,6 +7957,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ | ||||
|                 case GGML_UNARY_OP_SILU: | ||||
|                     func = ggml_cuda_silu; | ||||
|                     break; | ||||
|                 case GGML_UNARY_OP_RELU: | ||||
|                     func = ggml_cuda_relu; | ||||
|                     break; | ||||
|                 default: | ||||
|                     return false; | ||||
|             } break; | ||||
| @@ -7909,6 +7978,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ | ||||
|         case GGML_OP_SCALE: | ||||
|             func = ggml_cuda_scale; | ||||
|             break; | ||||
|         case GGML_OP_SQR: | ||||
|             func = ggml_cuda_sqr; | ||||
|             break; | ||||
|         case GGML_OP_CLAMP: | ||||
|             if (!any_on_device) { | ||||
|                 return false; | ||||
|   | ||||
| @@ -2877,6 +2877,13 @@ static void llm_load_tensors( | ||||
|                         ggml_backend_type backend_output; | ||||
|  | ||||
|                         if (n_gpu_layers > int(n_layer)) { | ||||
| #ifdef GGML_USE_CUBLAS | ||||
|                             if (n_gpu_layers > int(n_layer + 1)) { | ||||
|                                 LLAMA_LOG_ERROR("%s: CUDA backend missing Persimmon CUDA ops, can offload at most %ld layers. See: https://github.com/ggerganov/llama.cpp/issues/4038\n", | ||||
|                                     __func__, n_layer + 1); | ||||
|                                 throw std::runtime_error("Persimmon CUDA offload failed"); | ||||
|                             } | ||||
| #endif | ||||
|                             // norm is not performance relevant on its own but keeping it in VRAM reduces data copying | ||||
|                             // on Windows however this is detrimental unless everything is on the GPU | ||||
| #ifndef _WIN32 | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 Kerfuffle
					Kerfuffle