	CUDA/HIP: Fix fattn-vec-* when device warp size is not 32 (#12315)
When fattn-wmma was ported over to warp64, various bits that also touch fattn-vec were converted to a selectable warp size. However, the fattn-vec kernels don't work with 64-wide warps for now, so we need to avoid launching them with parameters for warp64.
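For illustration only, here is a minimal, self-contained CUDA sketch of the general pattern this change moves to: the warp size becomes a launch parameter that defaults to 32, so only call sites whose kernels are known to handle 64-wide warps pass the device's physical warp size. This is not the llama.cpp code; the names (sum_warp32, launch_reduce) are made up for the example.

// Illustrative sketch, not llama.cpp code: a kernel written for 32-wide warps
// and a launch helper whose warp_size parameter defaults to 32. A caller whose
// kernel handles 64-wide warps (e.g. a wmma-style path) would pass the device's
// warp size explicitly; everything else keeps the 32-lane default.
#include <cstdio>
#include <cuda_runtime.h>

// Stand-in for a fattn-vec-style kernel: the shuffle reduction assumes 32 lanes.
__global__ void sum_warp32(const float * x, float * out) {
    float v = x[threadIdx.x];
    for (int offset = 16; offset > 0; offset >>= 1) {
        v += __shfl_down_sync(0xFFFFFFFFu, v, offset);
    }
    if (threadIdx.x == 0) {
        *out = v;
    }
}

// Launch helper: the block width comes from warp_size, which defaults to 32.
// Only callers that know their kernel supports a wider warp override the default.
static void launch_reduce(const float * x, float * out, int warp_size = 32) {
    sum_warp32<<<1, warp_size>>>(x, out);
}

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    printf("physical warp size: %d\n", prop.warpSize); // 64 on GCN/CDNA, 32 elsewhere

    float *x, *out;
    cudaMallocManaged(&x, 32*sizeof(float));
    cudaMallocManaged(&out, sizeof(float));
    for (int i = 0; i < 32; ++i) {
        x[i] = 1.0f;
    }

    // The 32-lane kernel is always launched with the default warp_size, even on
    // a warp64 device; a warp64-capable kernel could instead be launched with
    // launch_reduce(x, out, prop.warpSize).
    launch_reduce(x, out);
    cudaDeviceSynchronize();
    printf("sum = %f\n", *out); // expected: 32.000000

    cudaFree(x);
    cudaFree(out);
    return 0;
}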
@@ -52,12 +52,11 @@ typedef half (*vec_dot_KQ_f16_t)(
 typedef float (*vec_dot_KQ_f32_t)(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds);

-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_0 * K_q4_0 = (const block_q4_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -93,12 +92,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_0(
     return sum;
 }

-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q4_1 * K_q4_1 = (const block_q4_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -138,12 +136,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q4_1(
     return sum;
 }

-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_0 * K_q5_0 = (const block_q5_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -186,12 +183,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_0(
     return sum;
 }

-template<typename T, int D>
+template<typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q5_1 * K_q5_1 = (const block_q5_1 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -238,12 +234,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q5_1(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) {

     const block_q8_0 * K_q8_0 = (const block_q8_0 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_v);

     T sum = 0.0f;
@@ -272,12 +267,11 @@ static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_q8_0(
     return sum;
 }

-template <typename T, int D>
+template <typename T, int D, int warp_size>
 static __device__ __forceinline__ T vec_dot_fattn_vec_KQ_f16(
     const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds_v) {

     const half2 * K_h2 = (const half2 *) K_c;
-    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
     GGML_UNUSED(Q_q8);
     GGML_UNUSED(Q_ds_v);

@@ -480,25 +474,25 @@ static __device__ __forceinline__ T dequantize_1_f16(const void * __restrict__ v
     return x[i];
 }

-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f16_t get_vec_dot_KQ_f16(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<half, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<half, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<half, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<half, D, warp_size> :
         nullptr;
 }

-template <int D>
+template <int D, int warp_size = WARP_SIZE>
 constexpr __device__ vec_dot_KQ_f32_t get_vec_dot_KQ_f32(ggml_type type_K) {
-    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D> :
-        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D> :
-        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D> :
-        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D> :
-        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D> :
-        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D> :
+    return type_K == GGML_TYPE_Q4_0 ? vec_dot_fattn_vec_KQ_q4_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q4_1 ? vec_dot_fattn_vec_KQ_q4_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_0 ? vec_dot_fattn_vec_KQ_q5_0<float, D, warp_size> :
+        type_K == GGML_TYPE_Q5_1 ? vec_dot_fattn_vec_KQ_q5_1<float, D, warp_size> :
+        type_K == GGML_TYPE_Q8_0 ? vec_dot_fattn_vec_KQ_q8_0<float, D, warp_size> :
+        type_K == GGML_TYPE_F16 ? vec_dot_fattn_vec_KQ_f16<float, D, warp_size> :
         nullptr;
 }

@@ -681,7 +675,8 @@ static void on_no_fattn_vec_case(const int D) {
 template <int D, int ncols1, int ncols2, int parallel_blocks, int KQ_stride>
 void launch_fattn(
     ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
-    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
+    const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V,
+    const int warp_size = WARP_SIZE
 ) {
     constexpr int ncols = ncols1 * ncols2;

@@ -704,8 +699,6 @@ void launch_fattn(

     GGML_ASSERT(Q->ne[3] == 1);

-    const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
-
     ggml_cuda_pool & pool = ctx.pool();
     cudaStream_t main_stream = ctx.stream();
     const int id  = ggml_cuda_get_device();
@@ -805,7 +798,6 @@ void launch_fattn(
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

     GGML_ASSERT(block_dim.x % warp_size == 0);
-    GGML_ASSERT(!GGML_CUDA_CC_IS_AMD(cc) || block_dim.x * block_dim.y <= 4 * (unsigned int)warp_size);
     fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
         (const char *) Q->data,
         K_data,

@@ -469,6 +469,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
     constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
     const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
+    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;

     float logit_softcap;
     memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
@@ -485,7 +486,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
         return;
     }
     if (2*blocks_num_pb1 < 2*nsm) {
@@ -500,7 +501,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
             fattn_kernel = flash_attn_ext_f16<
                 D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
         }
-        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+        launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
         return;
     }
     constexpr int parallel_blocks = 1;
@@ -514,7 +515,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggm
         fattn_kernel = flash_attn_ext_f16<
             D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
     }
-    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
+    launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true, warp_size);
 }

 void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
Author: uvos