mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-30 08:42:00 +00:00 
			
		
		
		
	CUDA: add roll (#14919)
* CUDA: add roll * Make everything const, use __restrict__
This commit is contained in:
		| @@ -31,6 +31,7 @@ | |||||||
| #include "ggml-cuda/pool2d.cuh" | #include "ggml-cuda/pool2d.cuh" | ||||||
| #include "ggml-cuda/quantize.cuh" | #include "ggml-cuda/quantize.cuh" | ||||||
| #include "ggml-cuda/rope.cuh" | #include "ggml-cuda/rope.cuh" | ||||||
|  | #include "ggml-cuda/roll.cuh" | ||||||
| #include "ggml-cuda/scale.cuh" | #include "ggml-cuda/scale.cuh" | ||||||
| #include "ggml-cuda/softmax.cuh" | #include "ggml-cuda/softmax.cuh" | ||||||
| #include "ggml-cuda/ssm-conv.cuh" | #include "ggml-cuda/ssm-conv.cuh" | ||||||
| @@ -2419,6 +2420,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg | |||||||
|         case GGML_OP_ROPE_BACK: |         case GGML_OP_ROPE_BACK: | ||||||
|             ggml_cuda_op_rope_back(ctx, dst); |             ggml_cuda_op_rope_back(ctx, dst); | ||||||
|             break; |             break; | ||||||
|  |         case GGML_OP_ROLL: | ||||||
|  |             ggml_cuda_op_roll(ctx, dst); | ||||||
|  |             break; | ||||||
|         case GGML_OP_IM2COL: |         case GGML_OP_IM2COL: | ||||||
|             ggml_cuda_op_im2col(ctx, dst); |             ggml_cuda_op_im2col(ctx, dst); | ||||||
|             break; |             break; | ||||||
| @@ -3411,6 +3415,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g | |||||||
|             memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); |             memcpy(&max_bias, (const float *) op->op_params + 1, sizeof(float)); | ||||||
|             return max_bias == 0.0f; |             return max_bias == 0.0f; | ||||||
|         } |         } | ||||||
|  |         case GGML_OP_ROLL: | ||||||
|  |             if(op->src[0]->type == GGML_TYPE_F32) { | ||||||
|  |                 return true; | ||||||
|  |             } | ||||||
|  |             return false; | ||||||
|         case GGML_OP_ROPE: |         case GGML_OP_ROPE: | ||||||
|         case GGML_OP_ROPE_BACK: { |         case GGML_OP_ROPE_BACK: { | ||||||
|             return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]); |             return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]); | ||||||
|   | |||||||
							
								
								
									
										67
									
								
								ggml/src/ggml-cuda/roll.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								ggml/src/ggml-cuda/roll.cu
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,67 @@ | |||||||
|  | #include "ggml-cuda/common.cuh" | ||||||
|  | #include "roll.cuh" | ||||||
|  |  | ||||||
|  | static __forceinline__ __device__ int64_t wrap_index(const int64_t idx, const int64_t ne) { | ||||||
|  |     if (idx < 0) { | ||||||
|  |         return idx + ne; | ||||||
|  |     } | ||||||
|  |     if (idx >= ne) { | ||||||
|  |         return idx - ne; | ||||||
|  |     } | ||||||
|  |     return idx; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | static __global__ void roll_f32_cuda(const float * __restrict__ src, | ||||||
|  |                                      float * __restrict__ dst, | ||||||
|  |                                      const int64_t ne00, | ||||||
|  |                                      const int64_t ne01, | ||||||
|  |                                      const int64_t ne02, | ||||||
|  |                                      const int64_t ne03, | ||||||
|  |                                      const int     s0, | ||||||
|  |                                      const int     s1, | ||||||
|  |                                      const int     s2, | ||||||
|  |                                      const int     s3) { | ||||||
|  |     const int64_t idx        = int64_t(blockDim.x) * blockIdx.x + threadIdx.x; | ||||||
|  |     const int64_t n_elements = ne00 * ne01 * ne02 * ne03; | ||||||
|  |  | ||||||
|  |     if (idx >= n_elements) { | ||||||
|  |         return; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     const int64_t i0 = idx % ne00; | ||||||
|  |     const int64_t i1 = (idx / ne00) % ne01; | ||||||
|  |     const int64_t i2 = (idx / (ne00 * ne01)) % ne02; | ||||||
|  |     const int64_t i3 = (idx / (ne00 * ne01 * ne02)) % ne03; | ||||||
|  |  | ||||||
|  |     const int64_t d0 = wrap_index(i0 - s0, ne00); | ||||||
|  |     const int64_t d1 = wrap_index(i1 - s1, ne01); | ||||||
|  |     const int64_t d2 = wrap_index(i2 - s2, ne02); | ||||||
|  |     const int64_t d3 = wrap_index(i3 - s3, ne03); | ||||||
|  |  | ||||||
|  |     dst[i3 * (ne00 * ne01 * ne02) + i2 * (ne01 * ne00) + i1 * ne00 + i0] = | ||||||
|  |         src[d3 * (ne00 * ne01 * ne02) + d2 * (ne01 * ne00) + d1 * ne00 + d0]; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { | ||||||
|  |     int s0 = dst->op_params[0]; | ||||||
|  |     int s1 = dst->op_params[1]; | ||||||
|  |     int s2 = dst->op_params[2]; | ||||||
|  |     int s3 = dst->op_params[3]; | ||||||
|  |  | ||||||
|  |     const ggml_tensor * src0   = dst->src[0]; | ||||||
|  |     const float *       src0_d = (const float *) dst->src[0]->data; | ||||||
|  |     float *             dst_d  = (float *) dst->data; | ||||||
|  |  | ||||||
|  |     GGML_TENSOR_UNARY_OP_LOCALS; | ||||||
|  |  | ||||||
|  |     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32); | ||||||
|  |     GGML_ASSERT(ggml_are_same_shape(dst->src[0], dst)); | ||||||
|  |  | ||||||
|  |     cudaStream_t stream = ctx.stream(); | ||||||
|  |  | ||||||
|  |     int64_t sz         = (ne00 * ne01 * ne02 * ne03); | ||||||
|  |     int64_t num_blocks = (sz + CUDA_ROLL_BLOCK_SIZE - 1) / CUDA_ROLL_BLOCK_SIZE; | ||||||
|  |  | ||||||
|  |     roll_f32_cuda<<<num_blocks, CUDA_ROLL_BLOCK_SIZE, 0, stream>>>( | ||||||
|  |         src0_d, dst_d, ne00, ne01, ne02, ne03, s0, s1, s2, s3); | ||||||
|  | } | ||||||
							
								
								
									
										5
									
								
								ggml/src/ggml-cuda/roll.cuh
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								ggml/src/ggml-cuda/roll.cuh
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | #include "common.cuh" | ||||||
|  |  | ||||||
|  | #define CUDA_ROLL_BLOCK_SIZE 256 | ||||||
|  |  | ||||||
|  | void ggml_cuda_op_roll(ggml_backend_cuda_context & ctx, ggml_tensor * dst); | ||||||
		Reference in New Issue
	
	Block a user
	 Aman Gupta
					Aman Gupta