mirror of
				https://github.com/ggml-org/llama.cpp.git
				synced 2025-10-31 08:51:55 +00:00 
			
		
		
		
	CUDA: Implemented row flattening for non-GLM RoPE (#2468)
This commit is contained in:
		
							
								
								
									
										23
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							
							
						
						
									
										23
									
								
								ggml-cuda.cu
									
									
									
									
									
								
							| @@ -3150,7 +3150,8 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne, | |||||||
| } | } | ||||||
|  |  | ||||||
| // rope == RoPE == rotary positional embedding | // rope == RoPE == rotary positional embedding | ||||||
| static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p, const float theta_scale) { | static __global__ void rope_f32(const float * x, float * dst, const int ncols, const float p0, | ||||||
|  |                                 const float p_delta, const int p_delta_rows, const float theta_scale) { | ||||||
|     const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x); |     const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x); | ||||||
|  |  | ||||||
|     if (col >= ncols) { |     if (col >= ncols) { | ||||||
| @@ -3160,7 +3161,7 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c | |||||||
|     const int row = blockDim.y*blockIdx.y + threadIdx.y; |     const int row = blockDim.y*blockIdx.y + threadIdx.y; | ||||||
|     const int i = row*ncols + col; |     const int i = row*ncols + col; | ||||||
|  |  | ||||||
|     const float theta = p*powf(theta_scale, col/2); |     const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2); | ||||||
|     const float sin_theta = sinf(theta); |     const float sin_theta = sinf(theta); | ||||||
|     const float cos_theta = cosf(theta); |     const float cos_theta = cosf(theta); | ||||||
|  |  | ||||||
| @@ -3764,12 +3765,13 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons | |||||||
|     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k); |     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k); | ||||||
| } | } | ||||||
|  |  | ||||||
| static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float theta_scale, cudaStream_t stream) { | static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, | ||||||
|  |                           const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { | ||||||
|     GGML_ASSERT(nrows % 2 == 0); |     GGML_ASSERT(nrows % 2 == 0); | ||||||
|     const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1); |     const dim3 block_dims(2*CUDA_ROPE_BLOCK_SIZE, 1, 1); | ||||||
|     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); |     const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); | ||||||
|     const dim3 block_nums(num_blocks_x, nrows, 1); |     const dim3 block_nums(num_blocks_x, nrows, 1); | ||||||
|     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale); |     rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale); | ||||||
| } | } | ||||||
|  |  | ||||||
| static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) { | static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) { | ||||||
| @@ -4465,6 +4467,7 @@ inline void ggml_cuda_op_rope( | |||||||
|     GGML_ASSERT(dst_ddf_i != nullptr); |     GGML_ASSERT(dst_ddf_i != nullptr); | ||||||
|  |  | ||||||
|     const int64_t ne00 = src0->ne[0]; |     const int64_t ne00 = src0->ne[0]; | ||||||
|  |     const int64_t ne01 = src0->ne[1]; | ||||||
|     const int64_t i01_diff = i01_high - i01_low; |     const int64_t i01_diff = i01_high - i01_low; | ||||||
|  |  | ||||||
|     const int n_past = ((int32_t *) dst->op_params)[0]; |     const int n_past = ((int32_t *) dst->op_params)[0]; | ||||||
| @@ -4478,17 +4481,18 @@ inline void ggml_cuda_op_rope( | |||||||
|     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); |     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float)); | ||||||
|  |  | ||||||
|     const float theta_scale = powf(freq_base, -2.0f/n_dims); |     const float theta_scale = powf(freq_base, -2.0f/n_dims); | ||||||
|     const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale; |  | ||||||
|  |  | ||||||
|     bool is_glm = mode & 4; |     const bool is_glm = mode & 4; | ||||||
|  |  | ||||||
|     // compute |     // compute | ||||||
|     if (is_glm) { |     if (is_glm) { | ||||||
|  |         const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale; | ||||||
|         const float id_p = min(p, n_ctx - 2.f); |         const float id_p = min(p, n_ctx - 2.f); | ||||||
|         const float block_p = max(p - (n_ctx - 2.f), 0.f); |         const float block_p = max(p - (n_ctx - 2.f), 0.f); | ||||||
|         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main); |         rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main); | ||||||
|     } else { |     } else { | ||||||
|         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main); |         const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; | ||||||
|  |         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main); | ||||||
|     } |     } | ||||||
|  |  | ||||||
|     (void) src1; |     (void) src1; | ||||||
| @@ -5103,7 +5107,10 @@ void ggml_cuda_soft_max(const ggml_tensor * src0, const ggml_tensor * src1, ggml | |||||||
|  |  | ||||||
| void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||||
|     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); |     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32); | ||||||
|     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, false); // FIXME flatten changes results |  | ||||||
|  |     const int mode = ((int32_t *) dst->op_params)[2]; | ||||||
|  |     const bool is_glm = mode & 4; | ||||||
|  |     ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm | ||||||
| } | } | ||||||
|  |  | ||||||
| void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Johannes Gäßler
					Johannes Gäßler